Enable SWA with CP for THD input format #2220
base: main
Conversation
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
caebcfe to 4987420
Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch L0
Signed-off-by: Sudhakar Singh <[email protected]>
…ormerEngine into enable_thd_cp_swa
Signed-off-by: Sudhakar Singh <[email protected]>
…ormerEngine into enable_thd_cp_swa
/te-ci pytorch L0
I see you've been running L0 CI tests. Did you happen to test with all the CP tests in L1? Thanks.
Greptile Overview
Greptile Summary
This PR enables Sliding Window Attention (SWA) with Context Parallelism (CP) for the THD (total, heads, dimension) packed sequence input format using All-to-All (A2A) communication. THD differs from standard BSHD/SBHD formats by supporting variable-length sequences tracked via cumulative sequence lengths (cu_seqlens). The implementation adds two new sequence reordering functions (reorder_seq_chunks_before_a2a_after_attn_thd and reorder_seq_chunks_after_a2a_before_attn_thd) that handle the special chunking logic required for THD's variable-length sequences during A2A communication. The main flash_attn_a2a_communicate function now branches on qkv_format to apply THD-specific reordering paths. Additionally, mask type restrictions are relaxed to allow padding masks with THD, which is necessary for handling variable-length sequences in SWA. This extends TransformerEngine's CP capabilities to production scenarios requiring packed sequences with sliding window attention for long-context models.
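For context, a minimal sketch (toy shapes and values, not taken from the PR) of how the THD format packs variable-length sequences along a single token dimension and how cu_seqlens marks their boundaries:

```python
import torch

# Three sequences of lengths 3, 5, and 2, packed along the "total" (token) dimension.
heads, head_dim = 2, 4
seqs = [torch.randn(n, heads, head_dim) for n in (3, 5, 2)]
thd = torch.cat(seqs, dim=0)              # shape: [total=10, heads, head_dim]
cu_seqlens = torch.tensor([0, 3, 8, 10])  # cumulative sequence-length boundaries

# Recover the i-th sequence from the packed tensor using its cu_seqlens slice.
i = 1
assert torch.equal(thd[cu_seqlens[i]:cu_seqlens[i + 1]], seqs[i])
```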
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 3/5 | Implements THD-specific sequence reordering functions and adds format branching in A2A communication logic |
| tests/pytorch/attention/test_attention_with_cp.py | 1/5 | Adds test configurations for THD+CP+SWA but references undefined config cp_1_4 causing KeyError |
Confidence score: 2/5
- This PR introduces functional changes but contains test configuration errors that will cause immediate failures
- Score reduced due to: (1) undefined test config `cp_1_4` that will cause a KeyError, (2) inconsistent skip logic between flash and fused attention tests for THD+A2A support, (3) side-effect-prone runtime config mutations in tests that may affect test isolation
- Pay close attention to tests/pytorch/attention/test_attention_with_cp.py line 76, where `cp_1_4` must be defined in `model_configs_flash_attn` before tests can run successfully
Sequence Diagram
sequenceDiagram
participant User
participant Test as test_attention_with_cp.py
participant CP as context_parallel.py
participant FusedAttn as fused_attn
participant A2A as AllToAll Communication
participant Attention as Attention Kernel
User->>Test: Run test with SWA+THD+CP (using A2A)
Test->>Test: Check compatibility filters
Note over Test: Skip if:<br/>- cp_comm_type != 'a2a'<br/>- qkv_format != 'thd'<br/>- window_size not supported
Test->>CP: attn_forward_func_with_cp()
Note over CP: cp_comm_type = 'a2a'<br/>qkv_format = 'thd'<br/>window_size set
CP->>CP: Validate configuration
Note over CP: Assert sliding_window_attn<br/>compatible with cp_comm_type='a2a'
CP->>A2A: flash_attn_a2a_communicate()
Note over A2A: Reorder sequence chunks<br/>for A2A before attention
A2A->>A2A: reorder_seq_chunks_after_a2a_before_attn_thd()
Note over A2A: Reorder THD format sequences<br/>using cu_seqlens and chunk_ids
A2A-->>CP: Reordered Q, K, V tensors
CP->>FusedAttn: fused_attn_fwd()
Note over FusedAttn: Forward pass with:<br/>- window_size applied<br/>- THD format<br/>- CP enabled
FusedAttn->>Attention: Execute attention kernel
Attention-->>FusedAttn: Attention output
FusedAttn-->>CP: out, aux_ctx_tensors, max_logit
CP->>A2A: flash_attn_a2a_communicate()
Note over A2A: Reorder sequence chunks<br/>for A2A after attention
A2A->>A2A: reorder_seq_chunks_before_a2a_after_attn_thd()
Note over A2A: Restore original sequence order<br/>for THD format
A2A-->>CP: Reordered output
CP-->>Test: Final output tensor
Test-->>User: Test passes/fails
Additional Comments (1)
- tests/pytorch/attention/test_attention_with_cp.py, lines 103-104 (link): logic: Inconsistency: lines 103-104 skip THD with A2A entirely, but lines 227-234 allow THD with `a2a` (only blocking `a2a+p2p`). This creates conflicting skip conditions between flash and fused attention tests.
2 files reviewed, 3 comments
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
    configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
    configs = ["cp_1_0", "cp_1_4", "cp_2_0", "cp_2_1", "cp_3_2", "cp_3_3"]
logic: cp_1_4 is not defined in model_configs_flash_attn (line 36-54). This will cause a KeyError when the test suite runs.
if qkv_format == "thd":
    print(f"config.attn_mask_type: {config.attn_mask_type}")
    if "causal" in config.attn_mask_type:
        config.attn_mask_type = "padding_causal"
    else:
        config.attn_mask_type = "padding"
style: Modifying the config object during test execution creates non-obvious side effects. The mask type change should be documented or handled differently to avoid confusion. Is this intended to be a global modification that affects subsequent tests, or should this be isolated per test run?
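If the mutation should stay test-local, one option (a sketch, not from the PR; the helper name and the use of copy.deepcopy on the config object are assumptions) is to resolve the mask type on a per-test copy:

```python
import copy

def resolve_mask_type_for_thd(config, qkv_format):
    # Work on a copy so the shared model-config object is not mutated across tests.
    config = copy.deepcopy(config)
    if qkv_format == "thd":
        # THD packs variable-length sequences, so the mask must carry padding information.
        config.attn_mask_type = (
            "padding_causal" if "causal" in config.attn_mask_type else "padding"
        )
    return config
```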
Signed-off-by: Sudhakar Singh <[email protected]>
/te-ci pytorch L1
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest changes address previously identified issues by removing cp_1_4 from the FlashAttention test configurations (which was causing a KeyError since it was never defined in model_configs_flash_attn) and by adding better documentation and context around the mask type modification for THD format. The THD format changes (lines 289-294) now transform mask types to their padding-aware equivalents ("padding" or "padding_causal") when processing packed sequences in THD format. This is necessary because THD uses variable-length packed sequences that require explicit padding information for correct attention computation. The changes integrate with the rest of the codebase by aligning test configurations with actual implementation capabilities for Sliding Window Attention with Context Parallelism on THD inputs using All-to-All communication, completing a feature that was previously incomplete.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/attention/test_attention_with_cp.py | 4/5 | Removed cp_1_4 from FlashAttention test configs (line 76) and updated THD format mask type handling (lines 289-294) to use padding-aware variants |
Confidence score: 4/5
- This PR is safe to merge with minimal risk, completing the THD+CP+SWA feature enablement
- Score reflects the correction of the previous KeyError issue and proper handling of THD format requirements; one point deducted due to limited context on whether the mask type mutation has broader testing implications
- Pay close attention to lines 289-294 to ensure the config mutation pattern is acceptable for the test framework and doesn't create unexpected side effects in parallel test execution
1 file reviewed, no comments
/te-ci pytorch
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest changes introduce additional test configurations to verify the newly enabled Sliding Window Attention (SWA) with Context Parallelism (CP) for THD (packed variable-length sequence) format using All-to-All (A2A) communication. The core implementation includes two new reordering functions (reorder_seq_chunks_before_a2a_after_attn_thd and reorder_seq_chunks_after_a2a_before_attn_thd) that handle the DualChunking pattern for THD format by building index arrays based on cumulative sequence lengths. The changes modify flash_attn_a2a_communicate to branch on qkv_format and handle THD-specific reshaping. Test skip conditions are updated to allow THD with basic A2A while maintaining restrictions for the hierarchical A2A+P2P variant, and the attention mask type is dynamically converted to padding-based masks when THD format is used (e.g., "causal" becomes "padding_causal"). These changes integrate with the existing context parallelism infrastructure in transformer_engine/pytorch/attention by extending the communication strategy to support variable-length packed sequences.
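To make the DualChunking pattern mentioned above concrete, here is a small illustrative sketch (toy values only; the assignment of chunk r plus its mirror 2*cp_size-1-r to rank r follows the per-rank layouts shown in the docstrings later in this thread):

```python
import torch

cp_size = 4
seq = torch.arange(8)            # one sequence of length 8
chunks = seq.chunk(2 * cp_size)  # 2*cp_size chunks per sequence

# Rank r holds chunk r and its mirror, balancing early and late tokens under causal masking.
for rank in range(cp_size):
    ids = [rank, 2 * cp_size - 1 - rank]
    print(rank, torch.cat([chunks[i] for i in ids]).tolist())
# rank 0 -> [0, 7], rank 1 -> [1, 6], rank 2 -> [2, 5], rank 3 -> [3, 4]
```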
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/attention/test_attention_with_cp.py | 3/5 | Adds test configurations cp_2_0 and cp_1_4, updates skip conditions to allow THD+A2A but block THD+A2A+P2P, and dynamically modifies attn_mask_type for THD format; contains a debug print statement that should be removed |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 3/5 | Implements THD reordering functions with complex index calculations, modifies flash_attn_a2a_communicate to branch on qkv_format, removes THD blocking assertion, and passes qkv_format/cu_seqlens through call stack |
Confidence score: 3/5
- This PR introduces complex new functionality with moderate risk due to intricate index manipulation logic for variable-length sequences that may fail in edge cases
- Score reflects two main concerns: (1) the new THD reordering functions build large index tensors on CPU before GPU transfer, which could cause performance issues, and (2) the complex nested index calculations are error-prone for edge cases like very small sequences or uneven divisibility by `cp_size`
- Pay close attention to transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py lines 262-395, where the THD reordering logic performs nested loops and complex arithmetic that may not handle all sequence length distributions correctly, and verify test coverage for edge cases
2 files reviewed, 3 comments
if qkv_format == "thd":
    print(f"config.attn_mask_type: {config.attn_mask_type}")
    if "causal" in config.attn_mask_type:
        config.attn_mask_type = "padding_causal"
    else:
        config.attn_mask_type = "padding"
style: Remove the debug print statement at line 290 before merging to production
def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0):
    """
    Reorder sequence chunks for A2A communication that happens after attention
    compute.

    Args:
        x: The input tensor to be reordered.
        cu_seqlens: The cumulative sequence lengths of the input tensor.
        cp_size: The number of ranks participating in context parallelism.
        seq_dim: The dimension in which to reorder.

    Returns:
        The reordered tensor.

    Example:
        x: [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  0.,  1.,  2.,  3.,  4.,  5.,
             6.,  7.,  0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  0.,  1.,  2.,  3.,
             4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.]
        cu_seqlens: [ 0, 8, 16, 24, 40]
        cp_size: 4
        Returns: [ 0.,  7.,  0.,  7.,  0.,  7.,  0.,  1., 14., 15.,  1.,  6.,  1.,  6.,
                   1.,  6.,  2.,  3., 12., 13.,  2.,  5.,  2.,  5.,  2.,  5.,  4.,  5.,
                  10., 11.,  3.,  4.,  3.,  4.,  3.,  4.,  6.,  7.,  8.,  9.]

    This logic is similar to how the DualChunking is done to split the sequence
    for each rank. Here, the indices of sequence chunks for all those ranks
    are concatenated together. So the returned tensor ends up looking as if
    the chunks from all the ranks are concatenated together.
    e.g. [
        0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
        1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
        2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
        3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
    ]
    """
    total_slices_of_any_sequence = 2 * cp_size
    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
    indices = [
        (
            # 1st segment
            torch.arange(
                seq_start + (cp_rank * slice_size), seq_start + ((cp_rank + 1) * slice_size)
            ),
            # 2nd segment
            torch.arange(
                seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
            ),
        )
        for cp_rank in range(cp_size)
        for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
    ]
    # flatten the list of tuples to a list
    indices = list(itertools.chain(*indices))
    indices = torch.cat(indices).to(device=cu_seqlens.device)
    return x.index_select(seq_dim, indices)
style: Potential performance issue: building indices on CPU then transferring to GPU. Consider generating indices directly on device for large sequences.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@greptile-apps, could you also suggest the code to do so
Looking at the function, the performance issue is that on line 321, the indices tensor is being created on CPU (default device) and then transferred to GPU with .to(device=cu_seqlens.device).
Here's the fix - generate the indices directly on the target device:
def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0):
    """
    Reorder sequence chunks for A2A communication that happens after attention
    compute.

    Args:
        x: The input tensor to be reordered.
        cu_seqlens: The cumulative sequence lengths of the input tensor.
        cp_size: The number of ranks participating in context parallelism.
        seq_dim: The dimension in which to reorder.

    Returns:
        The reordered tensor.

    Example:
        x: [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  0.,  1.,  2.,  3.,  4.,  5.,
             6.,  7.,  0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  0.,  1.,  2.,  3.,
             4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.]
        cu_seqlens: [ 0, 8, 16, 24, 40]
        cp_size: 4
        Returns: [ 0.,  7.,  0.,  7.,  0.,  7.,  0.,  1., 14., 15.,  1.,  6.,  1.,  6.,
                   1.,  6.,  2.,  3., 12., 13.,  2.,  5.,  2.,  5.,  2.,  5.,  4.,  5.,
                  10., 11.,  3.,  4.,  3.,  4.,  3.,  4.,  6.,  7.,  8.,  9.]

    This logic is similar to how the DualChunking is done to split the sequence
    for each rank. Here, the indices of sequence chunks for all those ranks
    are concatenated together. So the returned tensor ends up looking as if
    the chunks from all the ranks are concatenated together.
    e.g. [
        0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
        1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
        2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
        3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
    ]
    """
    total_slices_of_any_sequence = 2 * cp_size
    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
    indices = [
        (
            # 1st segment
            torch.arange(
                seq_start + (cp_rank * slice_size), seq_start + ((cp_rank + 1) * slice_size),
                device=cu_seqlens.device
            ),
            # 2nd segment
            torch.arange(
                seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
                device=cu_seqlens.device
            ),
        )
        for cp_rank in range(cp_size)
        for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
    ]
    # flatten the list of tuples to a list
    indices = list(itertools.chain(*indices))
    indices = torch.cat(indices)
    return x.index_select(seq_dim, indices)
The key changes:
- Lines 306-308: Added `device=cu_seqlens.device` to the first `torch.arange` call
- Lines 310-313: Added `device=cu_seqlens.device` to the second `torch.arange` call
- Line 321: Removed the `.to(device=cu_seqlens.device)` call since tensors are already on the correct device
This eliminates the CPU→GPU transfer overhead for large sequences.
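As a quick sanity check (a standalone sketch with made-up values, not part of the PR), both variants produce the same indices; the only difference is where the intermediate tensor is allocated:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
cu_seqlens = torch.tensor([0, 8, 16, 24, 40], device=device)

via_copy = torch.arange(0, 8).to(device=cu_seqlens.device)  # allocated on CPU, then copied
direct = torch.arange(0, 8, device=cu_seqlens.device)       # allocated on the target device
assert torch.equal(via_copy, direct)
```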
def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
    """
    Reorder sequence chunks for A2A communication that happens before attention
    compute.

    Args:
        x: The input tensor to be reordered.
        cu_seqlens: The cumulative sequence lengths of the input tensor.
        seq_chunk_ids: The sequence chunk ids of the input `x` which is to be reordered.
        cp_size: The number of ranks participating in context parallelism.
        seq_dim: The dimension in which to reorder.

    Returns:
        The reordered tensor.

    Example:
        x: [ 0.,  7.,  0.,  7.,  0.,  7.,  0.,  1., 14., 15.,  1.,  6.,  1.,  6.,
             1.,  6.,  2.,  3., 12., 13.,  2.,  5.,  2.,  5.,  2.,  5.,  4.,  5.,
            10., 11.,  3.,  4.,  3.,  4.,  3.,  4.,  6.,  7.,  8.,  9.]
        cu_seqlens: [ 0, 8, 16, 24, 40]
        seq_chunk_ids: [ 0, 2, 4, 6, 7, 5, 3, 1]
        cp_size: 4
        Returns: [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  0.,  1.,  2.,  3.,  4.,  5.,
                   6.,  7.,  0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  0.,  1.,  2.,  3.,
                   4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.]

    Note that the input sequences (x) are arranged after A2A communication as if DualChunked
    chunks on all the ranks are concatenated together in the `seq_dim`.
    e.g. [
        0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
        1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
        2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
        3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
    ]
    Then the logic to serialize the sequences is:
    1. For every sequence segment on any rank (denoted by `start` and `end`):
        1a. For every chunk (in `chunk_id`; the total number of chunks is twice the number of CP ranks):
            1aa. The first `cp_size` chunks form the first half of the whole sequence. Get those indices.
            1ab. The second `cp_size` chunks form the second half of the whole sequence. Get those indices.
        1b. Concatenate the indices of the first half and the second half.
    2. Reorder the entire input tensor by those indices.
    """
    max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size
    cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size

    # Go through all the sequence segments (the sizes should be the same from all the ranks)
    indices = [
        torch.arange(
            # Calculate 'left' boundary
            (
                start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
                if loc < cp_size
                else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
            ),
            # Calculate 'right' boundary
            (
                (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
                if loc < cp_size
                else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
            ),
        )
        for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
        for loc, chunk_id in enumerate(seq_chunk_ids)
    ]

    indices = torch.cat(indices).to(device=cu_seqlens.device)
    return x.index_select(seq_dim, indices)
style: Similar to the function above, this builds indices on CPU. For consistency and performance with large cu_seqlens, generate indices directly on device.
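A possible shape of that change, mirroring the fix above (a sketch only; the helper name build_indices_on_device is hypothetical, and everything except the device handling follows the index construction shown in the diff):

```python
import torch

def build_indices_on_device(cu_seqlens_on_any_cp_rank, seq_chunk_ids,
                            max_cum_seqlen_per_cp_rank, cp_size):
    # Same boundary arithmetic as the original list comprehension, but torch.arange is
    # given an explicit device so no CPU-to-GPU transfer is needed afterwards.
    device = cu_seqlens_on_any_cp_rank.device
    indices = [
        torch.arange(
            (start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
             if loc < cp_size
             else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)),
            ((start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
             if loc < cp_size
             else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)),
            device=device,
        )
        for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
        for loc, chunk_id in enumerate(seq_chunk_ids)
    ]
    return torch.cat(indices)  # already on the target device; no trailing .to(...) needed

# Toy usage matching the docstring example above (values are illustrative):
cu_seqlens = torch.tensor([0, 8, 16, 24, 40])
idx = build_indices_on_device(cu_seqlens // 4, [0, 2, 4, 6, 7, 5, 3, 1], cu_seqlens[-1] // 4, 4)
```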
Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The recent changes address all five previously identified issues: two performance improvements eliminate redundant CPU-to-GPU transfers by generating indices directly on the target device in context_parallel.py, one cleanup removes a debug print statement from the test file, one clarification keeps the intentional mask-type mutation in the test configuration (now understood as necessary for THD format handling), and one fix resolves a missing test configuration key. These changes are localized optimizations and corrections to the SWA+THD+CP feature implementation, streamlining device placement in index generation and cleaning up test artifacts.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 5/5 | Optimized device placement for index generation in THD format A2A communication—tensors now created directly on device |
| tests/pytorch/attention/test_attention_with_cp.py | 5/5 | Removed debug print statement and clarified THD mask type handling; test skip conditions remain unchanged |
Confidence score: 5/5
- This PR is safe to merge with minimal risk
- All previously flagged issues have been addressed: performance improvements eliminate CPU-to-GPU transfers, debug print removed, mask type handling clarified, and test configuration corrected
- No files require special attention; changes are straightforward optimizations and cleanup
Additional Comments (1)
- tests/pytorch/attention/test_attention_with_cp.py, line 104 (link): logic: This skip condition is now outdated (contradicts line 103). The PR enables THD format support for A2A.
2 files reviewed, 2 comments
    pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
    pytest.skip(
        "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
        " yet!"
    )
logic: The skip at line 229 contradicts the PR's stated goal of enabling THD+CP with A2A. This should be removed or updated.
It doesn't, since A2A+P2P is different than A2A only
You're right – A2A and A2A+P2P are different communication patterns. The skip at line 229 is specifically for all_gather, not for A2A, so there's no contradiction with the PR's goal of enabling THD+CP with A2A.
Looking at the code:
- Line 229: Skips when `cp_comm_type == "all_gather"`
- Lines 230-234: Skip when `cp_comm_type == "a2a+p2p"`
These are two separate conditions. The PR enables THD with A2A (plain "a2a"), which is not blocked by either of these skips. So the logic is actually correct and consistent with the PR's objective.
My apologies for the confusion in my original comment!
Description
Sliding Window Attention (SWA) with CP for the THD format is enabled using A2A communication.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR: