Llama bwd fix opt3#12

Merged
eric-tc-wong merged 10 commits into main from llama-bwd-fix-opt3 on May 7, 2026

@eric-tc-wong

This pull request adds optimizations and configuration options for the backward pass in the llama_fwd_ring_bwd_flash_attn.py module, aimed at making distributed attention computation more efficient and flexible. The main changes add support for custom stride patterns in the backward pass, overlap communication with computation more effectively, and refactor buffer management for performance and clarity. The reduce-scatter now operates on float32 tensors.
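To illustrate why a float32 reduce-scatter buffer can help, here is a minimal pure-Python sketch that emulates accumulation at two precisions with the standard `struct` module. It assumes the gradients would otherwise be summed in a 16-bit format (float16 is used here only because `struct` supports it; the PR does not state the previous dtype), and the contribution values are made up for the demonstration.

```python
import struct


def round_to(fmt: str, x: float) -> float:
    """Round a Python float to a struct format: 'e' = float16, 'f' = float32."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def accumulate(values, fmt: str) -> float:
    """Sum left to right, rounding the running total to `fmt` after every add."""
    total = 0.0
    for v in values:
        total = round_to(fmt, total + v)
    return total


# Hypothetical contributions from 8 ranks to a single gradient element,
# each already stored in float16 before the reduction.
contributions = [round_to("e", v) for v in [1000.0] + [0.2] * 7]

exact = sum(contributions)                    # float64 reference
fp16_total = accumulate(contributions, "e")   # low-precision accumulation
fp32_total = accumulate(contributions, "f")   # float32 accumulation

print(f"exact={exact:.4f}  fp16={fp16_total:.4f}  fp32={fp32_total:.4f}")
```

In this toy case the small contributions fall below half the float16 spacing near 1000 and are lost entirely, while the float32 running sum keeps them.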

Key improvements and new features:

Backward Pass Optimizations:

  • KEY: the reduce-scatter now uses float32 buffers for better precision.

  • Added support for bwd_head_first_stride and bwd_head_last_stride arguments to allow the first and last backward communication steps to use smaller strides, reducing the size of un-overlapped all-gather and reduce-scatter operations and improving pipeline efficiency.

  • Refactored the stride pattern logic in the backward pass to support arbitrary stride splits, with assertions and logic to ensure correctness and buffer allocation for each step. This enables flexible partitioning of attention heads across communication steps.
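The stride splitting described in this list can be sketched as a small standalone function. The splitting and assertions follow the snippet under review in this PR; how the final list is assembled (`first_part + full steps + last_part`), the divisibility check, and the function name `build_stride_pattern` are assumptions made for illustration, not the module's actual code.

```python
from typing import List, Optional


def build_stride_pattern(
    nheads_k: int,
    heads_k_stride: int,
    bwd_head_first_stride: Optional[int] = None,
    bwd_head_last_stride: Optional[int] = None,
) -> List[int]:
    """Return the number of KV heads handled in each backward step."""
    # Assumption: the head count divides evenly into full strides.
    assert nheads_k % heads_k_stride == 0, "nheads_k must be a multiple of heads_k_stride"
    n_full = nheads_k // heads_k_stride

    first_part: List[int] = []
    if bwd_head_first_stride is not None:
        assert 0 < bwd_head_first_stride < heads_k_stride
        # Split the first full stride into a small step plus the remainder.
        first_part = [bwd_head_first_stride, heads_k_stride - bwd_head_first_stride]
        n_full -= 1

    last_part: List[int] = []
    if bwd_head_last_stride is not None:
        assert 0 < bwd_head_last_stride < heads_k_stride
        # Split the last full stride into the remainder plus a small step.
        last_part = [heads_k_stride - bwd_head_last_stride, bwd_head_last_stride]
        n_full -= 1

    assert n_full >= 0, "not enough heads to split both the first and last stride"
    pattern = first_part + [heads_k_stride] * n_full + last_part
    assert sum(pattern) == nheads_k  # every head is covered exactly once
    return pattern


default_pattern = build_stride_pattern(8, 2)
split_pattern = build_stride_pattern(8, 2, bwd_head_first_stride=1, bwd_head_last_stride=1)
print(default_pattern, split_pattern)
```

With both options set, the first all-gather and the last reduce-scatter each cover a single head instead of a full stride, which is the part of the pipeline that cannot be overlapped.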

Communication and Buffer Management:

  • Introduced a ReduceScatterHandleManager to manage asynchronous reduce-scatter operations, allowing communication to be overlapped with computation for all interior steps and draining the handles at the correct cadence.

  • Refactored buffer allocation: now uses per-stride-width buffers and double-buffering for the main communication steps, with dedicated buffers for the first and last (potentially smaller) steps. This change improves memory efficiency and pipeline overlap.
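The overlap pattern described here can be sketched without any GPU or process group. The class below is not the PR's ReduceScatterHandleManager; it is a hypothetical stand-in (`HandleManagerSketch`) that keeps one asynchronous handle in flight and waits on the older one before its buffer slot is reused. `FakeHandle` mimics only the `.wait()` method that async collectives in torch.distributed return handles with.

```python
from collections import deque


class HandleManagerSketch:
    """Keep at most `max_in_flight` async handles pending; wait on older ones."""

    def __init__(self, max_in_flight: int = 1):
        self.max_in_flight = max_in_flight
        self._pending = deque()

    def submit(self, handle) -> None:
        """Register a new async handle, draining the oldest if over the limit."""
        self._pending.append(handle)
        while len(self._pending) > self.max_in_flight:
            self._pending.popleft().wait()

    def drain_all(self) -> None:
        """Wait on everything still pending (called after the last step)."""
        while self._pending:
            self._pending.popleft().wait()


class FakeHandle:
    """Stand-in for an async collective handle; records when it is waited on."""

    def __init__(self, step: int, log: list):
        self.step, self.log = step, log

    def wait(self) -> None:
        self.log.append(f"wait {self.step}")


log = []
mgr = HandleManagerSketch(max_in_flight=1)
for step in range(4):
    slot = step % 2  # double buffering: alternate between two buffer slots
    log.append(f"compute {step} -> slot {slot}")
    mgr.submit(FakeHandle(step, log))  # launch the "reduce-scatter" for this step
mgr.drain_all()
print("\n".join(log))
```

In this sketch the wait for step k happens only after the compute for step k+1 has been issued, so the two overlap, and slot `k % 2` is not written again until step k+2, after that wait has completed.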

API and Documentation Updates:

  • Updated the function signatures and context handling throughout the forward and backward functions to propagate the new stride arguments, ensuring they are available during autograd.

  • Expanded the docstrings and comments to document the new stride options, pipeline mechanics, and buffer management, making the codebase easier to understand and extend.

Project Documentation:

  • Added a new CLAUDE.md file with high-level project guidance, installation and test commands, architectural overview, and design constraints for contributors and users.

eric-tc-wong and others added 7 commits April 29, 2026 21:32
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
…duce_scatter

- Add head_first_stride and head_last_stride parameters to minimize serial collectives
- Implement async double-buffered reduce_scatter pipeline to overlap communication with compute
- Allocate per-width KV and dKV buffers to handle varying stride widths
- Split first/last iterations to reduce initial all_gather and final reduce_scatter sizes
- Plumb new parameters through LlamaFlashAttnFunc and public wrapper functions
- Create CLAUDE.md with architecture and command reference

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…g dedicated CUDA stream

Co-authored-by: Copilot <copilot@github.com>
…ons and remove unnecessary event signaling

Co-authored-by: Copilot <copilot@github.com>
…rations and streamline buffer management in llama_flash_attn_backward
…e and bwd_head_last_stride in llama_flash_attn_backward for clarity

Co-authored-by: Copilot <copilot@github.com>

Copilot AI left a comment

Pull request overview

This PR improves the efficiency and configurability of the Llama backward path in llama_fwd_ring_bwd_flash_attn.py by allowing custom head-stride patterns and overlapping reduce-scatter communication with computation, and introduces a new utility to manage async reduce-scatter handles.

Changes:

  • Add backward stride-splitting controls (bwd_head_first_stride, bwd_head_last_stride) to reduce un-overlapped communication in the first/last pipeline steps.
  • Refactor backward buffer management to allocate per-width buffers and overlap async reduce-scatter with compute via a new ReduceScatterHandleManager.
  • Add CLAUDE.md with repository usage/architecture guidance and common commands.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| ring_flash_attn/utils.py | Adds ReduceScatterHandleManager for managing async reduce-scatter operations (including coalesced variant when available). |
| ring_flash_attn/llama_fwd_ring_bwd_flash_attn.py | Implements backward stride-pattern support, buffer refactor, and async reduce-scatter overlap; plumbs new args through autograd contexts and public APIs. |
| CLAUDE.md | Adds contributor-facing project overview, key constraints, and verified commands for install/tests/benchmarks/formatting. |

Comment thread ring_flash_attn/utils.py
Comment thread ring_flash_attn/utils.py Outdated
Comment thread ring_flash_attn/llama_fwd_ring_bwd_flash_attn.py
Comment on lines +312 to +330
# ---- Build stride pattern (mirrors forward, extended for the tail) ----
n_full = nheads_k // heads_k_stride
if bwd_head_first_stride is not None:
    assert 0 < bwd_head_first_stride < heads_k_stride, (
        "bwd_head_first_stride must be between 0 and heads_k_stride"
    )
    first_part = [bwd_head_first_stride, heads_k_stride - bwd_head_first_stride]
    n_full -= 1
else:
    first_part = []

# Buffer for input to reduce_scatter (Needs to be rank-contiguous)
scatter_input_buffer = torch.empty(
    (2, world_size, batch_k, seq_k, heads_k_stride, head_dim),
    dtype=torch.float32,
    device=k.device,
)
if bwd_head_last_stride is not None:
    assert 0 < bwd_head_last_stride < heads_k_stride, (
        "bwd_head_last_stride must be between 0 and heads_k_stride"
    )
    last_part = [heads_k_stride - bwd_head_last_stride, bwd_head_last_stride]
    n_full -= 1
else:
    last_part = []
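
For a rough sense of what per-width float32 buffers cost, the sketch below computes a shape and byte count for each distinct width in a stride pattern, reusing the `(2, world_size, batch, seq, width, head_dim)` layout of `scatter_input_buffer` above. Double-buffering every width, the function name `plan_scatter_buffers`, and the example sizes are assumptions for illustration; the merged code may allocate the first and last steps differently.

```python
from math import prod


def plan_scatter_buffers(pattern, world_size, batch, seq, head_dim, bytes_per_elem=4):
    """Map each distinct stride width to (buffer shape, size in bytes).

    bytes_per_elem=4 corresponds to float32.
    """
    plans = {}
    for width in sorted(set(pattern)):
        # Leading dimension of 2 = two slots for double buffering (assumed for all widths).
        shape = (2, world_size, batch, seq, width, head_dim)
        plans[width] = (shape, prod(shape) * bytes_per_elem)
    return plans


# Hypothetical configuration: nheads_k=16, heads_k_stride=4, first and last stride of 1.
pattern = [1, 3, 4, 4, 3, 1]
plans = plan_scatter_buffers(pattern, world_size=8, batch=1, seq=4096, head_dim=128)
for width, (shape, nbytes) in plans.items():
    print(f"width={width} shape={shape} size={nbytes / 2**20:.0f} MiB")
```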
eric-tc-wong and others added 3 commits May 6, 2026 16:52
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <copilot@github.com>
@eric-tc-wong eric-tc-wong merged commit 72c9fe5 into main May 7, 2026
2 participants