Llama bwd fix opt3 #12
Merged
Co-authored-by: Copilot <copilot@github.com>
…duce_scatter
- Add head_first_stride and head_last_stride parameters to minimize serial collectives
- Implement async double-buffered reduce_scatter pipeline to overlap communication with compute
- Allocate per-width KV and dKV buffers to handle varying stride widths
- Split first/last iterations to reduce initial all_gather and final reduce_scatter sizes
- Plumb new parameters through LlamaFlashAttnFunc and public wrapper functions
- Create CLAUDE.md with architecture and command reference
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…g dedicated CUDA stream Co-authored-by: Copilot <copilot@github.com>
…ons and remove unnecessary event signaling Co-authored-by: Copilot <copilot@github.com>
…rations and streamline buffer management in llama_flash_attn_backward
…e and bwd_head_last_stride in llama_flash_attn_backward for clarity Co-authored-by: Copilot <copilot@github.com>
Pull request overview
This PR improves the efficiency and configurability of the Llama backward path in llama_fwd_ring_bwd_flash_attn.py by allowing custom head-stride patterns and overlapping reduce-scatter communication with computation, and introduces a new utility to manage async reduce-scatter handles.
Changes:
- Add backward stride-splitting controls (bwd_head_first_stride, bwd_head_last_stride) to reduce un-overlapped communication in the first/last pipeline steps.
- Refactor backward buffer management to allocate per-width buffers and overlap async reduce-scatter with compute via a new ReduceScatterHandleManager.
- Add CLAUDE.md with repository usage/architecture guidance and common commands.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| ring_flash_attn/utils.py | Adds ReduceScatterHandleManager for managing async reduce-scatter operations (including coalesced variant when available). |
| ring_flash_attn/llama_fwd_ring_bwd_flash_attn.py | Implements backward stride-pattern support, buffer refactor, and async reduce-scatter overlap; plumbs new args through autograd contexts and public APIs. |
| CLAUDE.md | Adds contributor-facing project overview, key constraints, and verified commands for install/tests/benchmarks/formatting. |
Comment on lines +312 to +330
    # ---- Build stride pattern (mirrors forward, extended for the tail) ----
    n_full = nheads_k // heads_k_stride
    if bwd_head_first_stride is not None:
        assert 0 < bwd_head_first_stride < heads_k_stride, (
            "bwd_head_first_stride must be between 0 and heads_k_stride"
        )
        first_part = [bwd_head_first_stride, heads_k_stride - bwd_head_first_stride]
        n_full -= 1
    else:
        first_part = []

    # Buffer for input to reduce_scatter (Needs to be rank-contiguous)
    scatter_input_buffer = torch.empty(
        (2, world_size, batch_k, seq_k, heads_k_stride, head_dim),
        dtype=torch.float32,
        device=k.device,
    )
    if bwd_head_last_stride is not None:
        assert 0 < bwd_head_last_stride < heads_k_stride, (
            "bwd_head_last_stride must be between 0 and heads_k_stride"
        )
        last_part = [heads_k_stride - bwd_head_last_stride, bwd_head_last_stride]
        n_full -= 1
    else:
        last_part = []
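The excerpt above builds `first_part`, `last_part`, and `n_full`, but does not show how they are combined. A minimal, self-contained sketch of one plausible combination follows. The function name `build_stride_pattern`, the concatenation order, the divisibility check, and the `n_full >= 0` guard are all assumptions made for illustration, not code from the PR.

```python
from typing import List, Optional


def build_stride_pattern(
    nheads_k: int,
    heads_k_stride: int,
    first_stride: Optional[int] = None,
    last_stride: Optional[int] = None,
) -> List[int]:
    """Return the per-step head widths (hypothetical helper, not the PR's code)."""
    # Assumption: the head count divides evenly into full strides.
    assert nheads_k % heads_k_stride == 0, "nheads_k must be a multiple of heads_k_stride"
    n_full = nheads_k // heads_k_stride

    first_part: List[int] = []
    last_part: List[int] = []

    if first_stride is not None:
        assert 0 < first_stride < heads_k_stride
        # One full stride is split into a small leading step plus the remainder.
        first_part = [first_stride, heads_k_stride - first_stride]
        n_full -= 1

    if last_stride is not None:
        assert 0 < last_stride < heads_k_stride
        # One full stride is split into the remainder plus a small trailing step.
        last_part = [heads_k_stride - last_stride, last_stride]
        n_full -= 1

    # Assumption: guard against splitting more full strides than exist
    # (e.g. a single full stride with both first and last splits requested).
    assert n_full >= 0, "not enough full strides to split both ends"

    pattern = first_part + [heads_k_stride] * n_full + last_part
    assert sum(pattern) == nheads_k
    return pattern


print(build_stride_pattern(8, 2))
print(build_stride_pattern(8, 2, first_stride=1, last_stride=1))
```

Under these assumptions, every pattern sums to `nheads_k`, so each KV head is processed exactly once regardless of how the ends are split.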
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This pull request introduces optimizations and configurability for the backward pass in the llama_fwd_ring_bwd_flash_attn.py module, focusing on improving the efficiency and flexibility of distributed attention computation. The main changes add support for custom stride patterns in the backward pass, overlap communication and computation more effectively, and refactor buffer management for better performance and clarity. Reduce-scatter now operates on float32 tensors.
Key improvements and new features:
Backward Pass Optimizations:
- KEY: float32 buffers for reduce-scatter, for better precision.
- Added support for bwd_head_first_stride and bwd_head_last_stride arguments to allow the first and last backward communication steps to use smaller strides, reducing the size of un-overlapped all-gather and reduce-scatter operations and improving pipeline efficiency.
- Refactored the stride pattern logic in the backward pass to support arbitrary stride splits, with assertions and logic to ensure correctness and buffer allocation for each step. This enables flexible partitioning of attention heads across communication steps.
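The precision motivation for float32 reduce-scatter buffers, noted at the top of this list, can be illustrated with a small pure-Python sketch. It assumes the gradients would otherwise be reduced in bfloat16 (the PR text does not state the original dtype) and it emulates a simple sequential sum rather than the actual reduction order of any collective library. The helper names are hypothetical.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round a finite value to the nearest bfloat16 (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 encoding.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def round_to_fp32(x: float) -> float:
    """Round a Python float (double) to the nearest float32."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def accumulate(values, round_fn) -> float:
    """Sum sequentially, rounding the running total after every addition."""
    total = 0.0
    for v in values:
        total = round_fn(total + v)
    return total


# One large partial gradient followed by eight small ones (illustrative values).
contributions = [256.0] + [1.0] * 8

bf16_sum = accumulate(contributions, round_to_bf16)
fp32_sum = accumulate(contributions, round_to_fp32)
print(bf16_sum, fp32_sum)  # 256.0 264.0
```

In bfloat16 the spacing between representable values near 256 is 2, so each added 1.0 is rounded away, while the float32 accumulation keeps all eight contributions.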
Communication and Buffer Management:
- Introduced a ReduceScatterHandleManager to manage asynchronous reduce-scatter operations, allowing communication to be overlapped with computation for all interior steps and draining the handles at the correct cadence.
- Refactored buffer allocation: now uses per-stride-width buffers and double-buffering for the main communication steps, with dedicated buffers for the first and last (potentially smaller) steps. This change improves memory efficiency and pipeline overlap.
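The handle-draining cadence described above can be sketched without any distributed backend. The class below is a hypothetical stand-in, not the PR's ReduceScatterHandleManager: it assumes each async collective returns an object with a `wait()` method (as `torch.distributed` work handles do when `async_op=True`) and that two buffer slots are reused in round-robin order. `FakeHandle` only records when it is waited on.

```python
from collections import deque


class HandleQueue:
    """FIFO of in-flight async handles (illustrative, not the PR's class)."""

    def __init__(self, max_in_flight: int = 2):
        if max_in_flight < 1:
            raise ValueError("max_in_flight must be at least 1")
        self.max_in_flight = max_in_flight
        self._pending = deque()

    def make_room(self) -> None:
        # Wait on the oldest handles until a buffer slot is free to overwrite.
        while len(self._pending) >= self.max_in_flight:
            self._pending.popleft().wait()

    def submit(self, handle) -> None:
        self._pending.append(handle)

    def drain(self) -> None:
        # Wait on everything still outstanding, oldest first.
        while self._pending:
            self._pending.popleft().wait()


class FakeHandle:
    """Stand-in for an async collective handle; logs its wait() call."""

    def __init__(self, step: int, log: list):
        self.step = step
        self.log = log

    def wait(self) -> None:
        self.log.append(f"wait{self.step}")


def run_pipeline(num_steps: int, num_buffers: int = 2) -> list:
    log = []
    queue = HandleQueue(max_in_flight=num_buffers)
    for step in range(num_steps):
        queue.make_room()            # slot (step % num_buffers) is now reusable
        slot = step % num_buffers
        log.append(f"compute{step}->slot{slot}")
        queue.submit(FakeHandle(step, log))  # launch async reduce-scatter
    queue.drain()
    return log


print(run_pipeline(4))
```

With two slots, the wait for step 0 is deferred until just before step 2 reuses its buffer, so each collective has one full compute step in which to complete.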
API and Documentation Updates:
- Updated the function signatures and context handling throughout the forward and backward functions to propagate the new stride arguments, ensuring they are available during autograd.
- Expanded the docstrings and comments to document the new stride options, pipeline mechanics, and buffer management, making the codebase easier to understand and extend.
Project Documentation:
- Added a CLAUDE.md file with high-level project guidance, installation and test commands, architectural overview, and design constraints for contributors and users.