Conversation

@sudhakarsingh27
Collaborator

Description

Sliding Window Attention (SWA) with context parallelism (CP) for the THD format is enabled using A2A communication.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • SWA+THD+CP (using A2A)
  • Test filters that allow this configuration

Signed-off-by: Sudhakar Singh <[email protected]>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@cyanguwa
Collaborator

I see you've been running L0 CI tests. Did you happen to test with all the CP tests in L1? Thanks.

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR enables Sliding Window Attention (SWA) with Context Parallelism (CP) for the THD (total, heads, dimension) packed sequence input format using All-to-All (A2A) communication. THD differs from standard BSHD/SBHD formats by supporting variable-length sequences tracked via cumulative sequence lengths (cu_seqlens). The implementation adds two new sequence reordering functions (reorder_seq_chunks_before_a2a_after_attn_thd and reorder_seq_chunks_after_a2a_before_attn_thd) that handle the special chunking logic required for THD's variable-length sequences during A2A communication. The main flash_attn_a2a_communicate function now branches on qkv_format to apply THD-specific reordering paths. Additionally, mask type restrictions are relaxed to allow padding masks with THD, which is necessary for handling variable-length sequences in SWA. This extends TransformerEngine's CP capabilities to production scenarios requiring packed sequences with sliding window attention for long-context models.
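
For readers unfamiliar with the packed layout, here is a minimal illustrative sketch (hypothetical tensor shapes, not the TransformerEngine API) of how THD tracks variable-length sequences with cu_seqlens instead of a padded batch dimension:

import torch

# Illustrative THD packing (hypothetical shapes, not the TransformerEngine API):
# two sequences of lengths 6 and 10 are packed along the token dimension and
# their boundaries are tracked with cumulative sequence lengths.
num_heads, head_dim = 8, 64
seq_lens = [6, 10]
total_tokens = sum(seq_lens)                               # t = 16
q_thd = torch.randn(total_tokens, num_heads, head_dim)     # THD: [t, h, d]
cu_seqlens = torch.tensor([0, 6, 16], dtype=torch.int32)   # [0, 6, 6+10]

# The equivalent BSHD tensor would be a padded [2, 10, 8, 64]; THD avoids the padding
# but needs cu_seqlens wherever sequence boundaries matter (e.g. the A2A reordering here).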

Important Files Changed

Filename | Score | Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 3/5 | Implements THD-specific sequence reordering functions and adds format branching in A2A communication logic
tests/pytorch/attention/test_attention_with_cp.py | 1/5 | Adds test configurations for THD+CP+SWA but references undefined config cp_1_4 causing KeyError

Confidence score: 2/5

  • This PR introduces functional changes but contains test configuration errors that will cause immediate failures
  • Score reduced due to: (1) undefined test config cp_1_4 that will cause KeyError, (2) inconsistent skip logic between flash and fused attention tests for THD+A2A support, (3) side-effect-prone runtime config mutations in tests that may affect test isolation
  • Pay close attention to tests/pytorch/attention/test_attention_with_cp.py line 76 where cp_1_4 must be defined in model_configs_flash_attn before tests can run successfully

Sequence Diagram

sequenceDiagram
    participant User
    participant Test as test_attention_with_cp.py
    participant CP as context_parallel.py
    participant FusedAttn as fused_attn
    participant A2A as AllToAll Communication
    participant Attention as Attention Kernel

    User->>Test: Run test with SWA+THD+CP (using A2A)
    Test->>Test: Check compatibility filters
    Note over Test: Skip if:<br/>- cp_comm_type != 'a2a'<br/>- qkv_format != 'thd'<br/>- window_size not supported
    
    Test->>CP: attn_forward_func_with_cp()
    Note over CP: cp_comm_type = 'a2a'<br/>qkv_format = 'thd'<br/>window_size set
    
    CP->>CP: Validate configuration
    Note over CP: Assert sliding_window_attn<br/>compatible with cp_comm_type='a2a'
    
    CP->>A2A: flash_attn_a2a_communicate()
    Note over A2A: Reorder sequence chunks<br/>for A2A before attention
    
    A2A->>A2A: reorder_seq_chunks_after_a2a_before_attn_thd()
    Note over A2A: Reorder THD format sequences<br/>using cu_seqlens and chunk_ids
    
    A2A-->>CP: Reordered Q, K, V tensors
    
    CP->>FusedAttn: fused_attn_fwd()
    Note over FusedAttn: Forward pass with:<br/>- window_size applied<br/>- THD format<br/>- CP enabled
    
    FusedAttn->>Attention: Execute attention kernel
    Attention-->>FusedAttn: Attention output
    
    FusedAttn-->>CP: out, aux_ctx_tensors, max_logit
    
    CP->>A2A: flash_attn_a2a_communicate()
    Note over A2A: Reorder sequence chunks<br/>for A2A after attention
    
    A2A->>A2A: reorder_seq_chunks_before_a2a_after_attn_thd()
    Note over A2A: Restore original sequence order<br/>for THD format
    
    A2A-->>CP: Reordered output
    
    CP-->>Test: Final output tensor
    Test-->>User: Test passes/fails

Additional Comments (1)

  1. tests/pytorch/attention/test_attention_with_cp.py, lines 103-104

    logic: Inconsistency: lines 103-104 skip THD with A2A entirely, but lines 227-234 allow THD with a2a (only blocking a2a+p2p). This creates conflicting skip conditions between the flash and fused attention tests.

2 files reviewed, 3 comments


cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
-    configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
+    configs = ["cp_1_0", "cp_1_4", "cp_2_0", "cp_2_1", "cp_3_2", "cp_3_3"]

logic: cp_1_4 is not defined in model_configs_flash_attn (lines 36-54). This will cause a KeyError when the test suite runs.

Comment on lines 289 to 294
if qkv_format == "thd":
    print(f"config.attn_mask_type: {config.attn_mask_type}")
    if "causal" in config.attn_mask_type:
        config.attn_mask_type = "padding_causal"
    else:
        config.attn_mask_type = "padding"

style: Modifying the config object during test execution creates non-obvious side effects. The mask type change should be documented or handled differently to avoid confusion. Is this intended to be a global modification that affects subsequent tests, or should this be isolated per test run?

Signed-off-by: Sudhakar Singh <[email protected]>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest changes address previously identified issues by removing cp_1_4 from the FlashAttention test configurations (which was causing a KeyError since it was never defined in model_configs_flash_attn) and by adding better documentation and context around the mask type modification for THD format. The THD format changes (lines 289-294) now transform mask types to their padding-aware equivalents ("padding" or "padding_causal") when processing packed sequences in THD format. This is necessary because THD uses variable-length packed sequences that require explicit padding information for correct attention computation. The changes integrate with the rest of the codebase by aligning test configurations with actual implementation capabilities for Sliding Window Attention with Context Parallelism on THD inputs using All-to-All communication, completing a feature that was previously incomplete.
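
As a hedged sketch of the mask-type adjustment described above (it mirrors the test logic quoted earlier; thd_mask_type is a hypothetical helper, not a TE API):

# Hedged sketch mirroring the test-side mask adjustment for THD (not a TE API):
# packed variable-length sequences need a padding-aware mask variant.
def thd_mask_type(attn_mask_type: str) -> str:
    if "causal" in attn_mask_type:
        return "padding_causal"
    return "padding"

assert thd_mask_type("causal") == "padding_causal"
assert thd_mask_type("no_mask") == "padding"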

Important Files Changed

Filename | Score | Overview
tests/pytorch/attention/test_attention_with_cp.py | 4/5 | Removed cp_1_4 from FlashAttention test configs (line 76) and updated THD format mask type handling (lines 289-294) to use padding-aware variants

Confidence score: 4/5

  • This PR is safe to merge with minimal risk, completing the THD+CP+SWA feature enablement
  • Score reflects the correction of the previous KeyError issue and proper handling of THD format requirements; one point deducted due to limited context on whether the mask type mutation has broader testing implications
  • Pay close attention to lines 289-294 to ensure the config mutation pattern is acceptable for the test framework and doesn't create unexpected side effects in parallel test execution

1 file reviewed, no comments


@sudhakarsingh27
Collaborator Author

/te-ci pytorch

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest changes introduce additional test configurations to verify the newly enabled Sliding Window Attention (SWA) with Context Parallelism (CP) for THD (packed variable-length sequence) format using All-to-All (A2A) communication. The core implementation includes two new reordering functions (reorder_seq_chunks_before_a2a_after_attn_thd and reorder_seq_chunks_after_a2a_before_attn_thd) that handle the DualChunking pattern for THD format by building index arrays based on cumulative sequence lengths. The changes modify flash_attn_a2a_communicate to branch on qkv_format and handle THD-specific reshaping. Test skip conditions are updated to allow THD with basic A2A while maintaining restrictions for the hierarchical A2A+P2P variant, and the attention mask type is dynamically converted to padding-based masks when THD format is used (e.g., "causal" becomes "padding_causal"). These changes integrate with the existing context parallelism infrastructure in transformer_engine/pytorch/attention by extending the communication strategy to support variable-length packed sequences.
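
A minimal sketch of the DualChunking split referenced above, consistent with the docstring examples quoted below (dual_chunk_for_rank is a hypothetical helper, not TE code): each sequence is cut into 2*cp_size chunks and rank r keeps chunks r and 2*cp_size-1-r, which balances causal-attention work across CP ranks.

import torch

# Illustrative DualChunking split (matches the docstring examples below; not TE API):
# each sequence is cut into 2*cp_size chunks and rank r keeps chunks r and
# (2*cp_size - 1 - r), balancing causal-attention work across CP ranks.
def dual_chunk_for_rank(seq: torch.Tensor, cp_size: int, rank: int) -> torch.Tensor:
    chunks = seq.chunk(2 * cp_size)
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]])

seq = torch.arange(8.0)   # a single sequence of length 8
for r in range(4):
    print(r, dual_chunk_for_rank(seq, cp_size=4, rank=r).tolist())
# 0 [0.0, 7.0]
# 1 [1.0, 6.0]
# 2 [2.0, 5.0]
# 3 [3.0, 4.0]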

Important Files Changed

Filename | Score | Overview
tests/pytorch/attention/test_attention_with_cp.py | 3/5 | Adds test configurations cp_2_0 and cp_1_4, updates skip conditions to allow THD+A2A but block THD+A2A+P2P, and dynamically modifies attn_mask_type for THD format; contains a debug print statement that should be removed
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 3/5 | Implements THD reordering functions with complex index calculations, modifies flash_attn_a2a_communicate to branch on qkv_format, removes THD blocking assertion, and passes qkv_format/cu_seqlens through call stack

Confidence score: 3/5

  • This PR introduces complex new functionality with moderate risk due to intricate index manipulation logic for variable-length sequences that may fail in edge cases
  • Score reflects two main concerns: (1) the new THD reordering functions build large index tensors on CPU before GPU transfer which could cause performance issues, and (2) the complex nested index calculations are error-prone for edge cases like very small sequences or uneven divisibility by cp_size
  • Pay close attention to transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py lines 262-395 where the THD reordering logic performs nested loops and complex arithmetic that may not handle all sequence length distributions correctly, and verify test coverage for edge cases

2 files reviewed, 3 comments


Comment on lines 289 to 294
if qkv_format == "thd":
    print(f"config.attn_mask_type: {config.attn_mask_type}")
    if "causal" in config.attn_mask_type:
        config.attn_mask_type = "padding_causal"
    else:
        config.attn_mask_type = "padding"

style: Remove the debug print statement at line 290 before merging to production

Comment on lines 262 to 322
def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0):
    """
    Reorder sequence chunks for A2A communication that happens after attention
    compute.

    Args:
        x: The input tensor to be reordered.
        cu_seqlens: The cumulative sequence lengths of the input tensor.
        cp_size: The number of ranks participating in context parallelism.
        seq_dim: The dimension in which to reorder.

    Returns:
        The reordered tensor.

    Example:
        x: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5.,
             6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3.,
             4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
        cu_seqlens: [ 0, 2, 4, 6, 10]
        cp_size: 4
        Returns: [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., 1., 6., 1., 6.,
                   1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5.,
                   10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.]

    This logic is similar to how the DualChunking is done to split the sequence
    for each rank. Here, the indices of sequence chunks for all those ranks
    are concatenated together. So the returned tensor ends up looking as if
    the chunks from all the ranks are concatenated together, e.g.
        [
          0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
          1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
          2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
          3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
        ]
    """
    total_slices_of_any_sequence = 2 * cp_size
    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence

    indices = [
        (
            # 1st segment
            torch.arange(
                seq_start + (cp_rank * slice_size), seq_start + ((cp_rank + 1) * slice_size)
            ),
            # 2nd segment
            torch.arange(
                seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
            ),
        )
        for cp_rank in range(cp_size)
        for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
    ]

    # flatten the list of tuples to a list
    indices = list(itertools.chain(*indices))
    indices = torch.cat(indices).to(device=cu_seqlens.device)
    return x.index_select(seq_dim, indices)

style: Potential performance issue: building indices on CPU then transferring to GPU. Consider generating indices directly on device for large sequences.

@sudhakarsingh27
Collaborator Author

@greptile-apps, could you also suggest the code to do so?

greptile-apps bot

Looking at the function, the performance issue is that on line 321, the indices tensor is being created on CPU (default device) and then transferred to GPU with .to(device=cu_seqlens.device).

Here's the fix - generate the indices directly on the target device:

Suggested change
def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0):
    """
    Reorder sequence chunks for A2A communication that happens after attention
    compute.

    Args:
        x: The input tensor to be reordered.
        cu_seqlens: The cumulative sequence lengths of the input tensor.
        cp_size: The number of ranks participating in context parallelism.
        seq_dim: The dimension in which to reorder.

    Returns:
        The reordered tensor.

    Example:
        x: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5.,
             6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3.,
             4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
        cu_seqlens: [ 0, 2, 4, 6, 10]
        cp_size: 4
        Returns: [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., 1., 6., 1., 6.,
                   1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5.,
                   10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.]

    This logic is similar to how the DualChunking is done to split the sequence
    for each rank. Here, the indices of sequence chunks for all those ranks
    are concatenated together. So the returned tensor ends up looking as if
    the chunks from all the ranks are concatenated together, e.g.
        [
          0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
          1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
          2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
          3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
        ]
    """
    total_slices_of_any_sequence = 2 * cp_size
    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
    indices = [
        (
            # 1st segment
            torch.arange(
                seq_start + (cp_rank * slice_size), seq_start + ((cp_rank + 1) * slice_size),
                device=cu_seqlens.device
            ),
            # 2nd segment
            torch.arange(
                seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
                device=cu_seqlens.device
            ),
        )
        for cp_rank in range(cp_size)
        for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
    ]
    # flatten the list of tuples to a list
    indices = list(itertools.chain(*indices))
    indices = torch.cat(indices)
    return x.index_select(seq_dim, indices)

The key changes:

  1. Lines 306-308: Added device=cu_seqlens.device to the first torch.arange call
  2. Lines 310-313: Added device=cu_seqlens.device to the second torch.arange call
  3. Line 321: Removed the .to(device=cu_seqlens.device) call since tensors are already on the correct device

This eliminates the CPU→GPU transfer overhead for large sequences.
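
For completeness, a minimal standalone illustration of the difference (assumes a CUDA device is available; not part of the PR):

import torch

# Quick illustration of the on-device allocation pattern (requires a CUDA device):
if torch.cuda.is_available():
    dev = torch.device("cuda")
    a = torch.arange(0, 1024).to(device=dev)   # built on CPU, then copied to the GPU
    b = torch.arange(0, 1024, device=dev)      # allocated directly on the GPU
    assert torch.equal(a, b)                   # same values, one fewer host-to-device copy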

Comment on lines 325 to 395
def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
    """
    Reorder sequence chunks for A2A communication that happens before attention
    compute.

    Args:
        x: The input tensor to be reordered.
        cu_seqlens: The cumulative sequence lengths of the input tensor.
        seq_chunk_ids: The sequence chunk ids of the input `x` which is to be reordered.
        cp_size: The number of ranks participating in context parallelism.
        seq_dim: The dimension in which to reorder.

    Returns:
        The reordered tensor.

    Example:
        x: [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., 1., 6., 1., 6.,
             1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5.,
             10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.]
        cu_seqlens: [ 0, 8, 16, 24, 40]
        seq_chunk_ids: [ 0, 2, 4, 6, 7, 5, 3, 1]
        cp_size: 4
        Returns: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5.,
                   6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5.,
                   4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]

    Note that the input sequences (x) are arranged after A2A communication as if DualChunked
    chunks on all the ranks are concatenated together in the `seq_dim`, e.g.
        [
          0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
          1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
          2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
          3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
        ]

    The logic to serialize the sequences is then:
        1. For every sequence segment on any rank (denoted by `start` and `end`):
            1a. For every chunk (indexed by `chunk_id`; there are twice as many chunks
                as CP ranks):
                1aa. The first `cp_size` chunks form the first half of the whole
                     sequence. Get those indices.
                1ab. The second `cp_size` chunks form the second half of the whole
                     sequence. Get those indices.
            1b. Concatenate the indices of the first half and the second half.
        2. Reorder the entire input tensor by those indices.
    """

    max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size
    cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size

    # Go through all the sequence segments (the sizes should be the same from all the ranks)
    indices = [
        torch.arange(
            # Calculate 'left' boundary
            (
                start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
                if loc < cp_size
                else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
            ),
            # Calculate 'right' boundary
            (
                (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
                if loc < cp_size
                else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
            ),
        )
        for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
        for loc, chunk_id in enumerate(seq_chunk_ids)
    ]

    indices = torch.cat(indices).to(device=cu_seqlens.device)
    return x.index_select(seq_dim, indices)

style: Similar to the function above, this builds indices on CPU. For consistency and performance with large cu_seqlens, generate indices directly on device.
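
A hedged standalone sketch of the same on-device pattern (gather_ranges is a hypothetical helper, not the PR's code): build the index tensor directly on the input's device so no .to(device=...) copy is needed before index_select.

import torch

# Hypothetical helper illustrating the on-device indexing pattern.
def gather_ranges(x: torch.Tensor, bounds, seq_dim: int = 0) -> torch.Tensor:
    # `bounds` is a list of (start, end) pairs; indices are created on x.device,
    # so no .to(device=...) copy is needed before index_select.
    idx = torch.cat([torch.arange(s, e, device=x.device) for s, e in bounds])
    return x.index_select(seq_dim, idx)

x = torch.arange(10.0)
print(gather_ranges(x, [(0, 2), (6, 8)]).tolist())  # [0.0, 1.0, 6.0, 7.0]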

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The recent changes address all five previously identified issues: two performance improvements eliminate redundant CPU-to-GPU transfers by generating indices directly on the target device in context_parallel.py, one cleanup removes a debug print statement from the test file, one clarification keeps the intentional mask-type mutation in the test configuration (now understood as necessary for THD format handling), and one fix resolves a missing test configuration key. These changes are localized optimizations and corrections to the SWA+THD+CP feature implementation, streamlining device placement in index generation and cleaning up test artifacts.

Important Files Changed

Filename | Score | Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 5/5 | Optimized device placement for index generation in THD format A2A communication; tensors are now created directly on device
tests/pytorch/attention/test_attention_with_cp.py | 5/5 | Removed debug print statement and clarified THD mask type handling; test skip conditions remain unchanged

Confidence score: 5/5

  • This PR is safe to merge with minimal risk
  • All previously flagged issues have been addressed: performance improvements eliminate CPU-to-GPU transfers, debug print removed, mask type handling clarified, and test configuration corrected
  • No files require special attention; changes are straightforward optimizations and cleanup

Additional Comments (1)

  1. tests/pytorch/attention/test_attention_with_cp.py, line 104

    logic: This skip condition is now outdated (contradicts line 103). The PR enables THD format support for A2A.

2 files reviewed, 2 comments


Comment on lines +229 to +234
        pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
    if cp_comm_type == "a2a+p2p":
        pytest.skip(
            "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
            " yet!"
        )

logic: The skip at line 229 contradicts the PR's stated goal of enabling THD+CP with A2A. This should be removed or updated.

@sudhakarsingh27
Collaborator Author

It doesn't, since A2A+P2P is different from A2A only.

greptile-apps bot

You're right – A2A and A2A+P2P are different communication patterns. The skip at line 229 is specifically for all_gather, not for A2A, so there's no contradiction with the PR's goal of enabling THD+CP with A2A.

Looking at the code:

  • Line 229: Skips when cp_comm_type == "all_gather"
  • Line 230-234: Skips when cp_comm_type == "a2a+p2p"

These are two separate conditions. The PR enables THD with A2A (plain "a2a"), which is not blocked by either of these skips. So the logic is actually correct and consistent with the PR's objective.

My apologies for the confusion in my original comment!
