Refactor set_multi_step_attn_mask for arbitrary step #348
base: main
Conversation
Walkthrough: Refactors set_multi_step_attention_mask so the multi-step attention mask is built iteratively for an arbitrary step count.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Eagle as MegatronEagle
    rect rgba(220,235,245,0.5)
        note over Eagle: set_multi_step_attention_mask(attn_mask, step)
        Caller->>Eagle: set_multi_step_attention_mask(attn_mask, step)
        loop iter = 2 .. step
            Eagle->>Eagle: build zero_mask, mask_0, mask_1
            Eagle->>Eagle: concat per-iteration masks into attn_mask
        end
        Eagle-->>Caller: updated attn_mask
    end
```
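To make the loop in the diagram concrete, here is a minimal, self-contained toy sketch of the block-growth pattern. It is not the actual megatron_eagle implementation: the contents of `mask_0`/`mask_1` are placeholders (all-True) and only the shape mechanics of the per-iteration concatenation are shown.

```python
import torch


def grow_mask_sketch(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    """Illustrative only: grow a (b, h, s, s) boolean mask to (b, h, step*s, step*s)
    by appending one s-row / s-column block per extra step."""
    b, h, _, s = attn_mask.shape
    for _ in range(2, step + 1):
        cur_h, cur_w = attn_mask.shape[-2], attn_mask.shape[-1]
        zero_mask = attn_mask.new_ones(b, h, cur_h, s)   # right padding for existing rows
        mask_0 = attn_mask.new_ones(b, h, s, cur_w)      # placeholder: new rows vs. old columns
        mask_1 = attn_mask.new_ones(b, h, s, s)          # placeholder: new rows vs. new columns
        top = torch.cat((attn_mask, zero_mask), dim=-1)  # (b, h, cur_h, cur_w + s)
        bottom = torch.cat((mask_0, mask_1), dim=-1)     # (b, h, s, cur_w + s)
        attn_mask = torch.cat((top, bottom), dim=-2)     # (b, h, cur_h + s, cur_w + s)
    return attn_mask


# Shape check: growing a (1, 1, 4, 4) mask with step=2 yields (1, 1, 8, 8).
base = torch.ones(1, 1, 4, 4, dtype=torch.bool)
print(grow_mask_sketch(base, step=2).shape)  # torch.Size([1, 1, 8, 8])
```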
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
317-318: Optional: vectorize the diagonal write to avoid a Python loop. Minor perf tidy-up; reduces Python overhead for large s.
```diff
-for i in range(step_idx - 1, s - 1):
-    mask_1[:, :, i, i] = False
+if step_idx - 1 < s - 1:
+    idx = torch.arange(step_idx - 1, s - 1, device=attn_mask.device)
+    mask_1[:, :, idx, idx] = False
```
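For reference, a tiny standalone check that the `torch.arange`-based diagonal write produces the same result as the Python loop (the values of `s` and `step_idx` below are arbitrary illustration, not taken from the module):

```python
import torch

s, step_idx = 8, 3
loop_mask = torch.ones(1, 1, s, s, dtype=torch.bool)
vec_mask = loop_mask.clone()

# Python-loop version
for i in range(step_idx - 1, s - 1):
    loop_mask[:, :, i, i] = False

# Vectorized version from the suggestion above
if step_idx - 1 < s - 1:
    idx = torch.arange(step_idx - 1, s - 1, device=vec_mask.device)
    vec_mask[:, :, idx, idx] = False

assert torch.equal(loop_mask, vec_mask)
```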
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
309-327: Generalization approach looks good. The iterative block-growth builds a (step·s)×(step·s) mask and removes the hardcoded branches. Nice cleanup.
Please re-run the sandbox regression with:
- CUDA tensors (to confirm device fix), and
- very short sequences (e.g., s ∈ {1,2,3}) with step ∈ {2,3,4} to ensure no out-of-bounds and correct shapes (a shape-check sketch follows below).
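Following up on the short-sequence request, a hedged sketch of such a check. It assumes `set_multi_step_attention_mask(attn_mask, step)` is importable from `modelopt/torch/speculative/plugins/megatron_eagle.py`, takes a `(b, 1, s, s)` boolean mask, and returns the grown mask; the mask convention and exact output shape are assumptions to adjust against the real code.

```python
import torch
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask

device = "cuda" if torch.cuda.is_available() else "cpu"  # CUDA path exercises the device fix
for s in (1, 2, 3):
    for step in (2, 3, 4):
        # Causal base mask of shape (b, 1, s, s); the True/False convention may differ in Megatron.
        base = torch.tril(torch.ones(s, s, dtype=torch.bool, device=device)).expand(1, 1, s, s).clone()
        out = set_multi_step_attention_mask(base, step)
        # Expect roughly (1, 1, step*s, step*s) unless the step-capping suggestion is applied.
        print(f"s={s} step={step} -> {tuple(out.shape)}")
```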
Codecov Report: ✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #348      +/-   ##
==========================================
- Coverage   73.84%   73.83%   -0.01%
==========================================
  Files         172      172
  Lines       17453    17453
==========================================
- Hits        12888    12887       -1
- Misses       4565     4566       +1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
309-327: Fix device/dtype mismatch and short-sequence OOB; don't shadow the built-in `iter`.

- `zero_mask`/`mask_1` are allocated on CPU; concatenating with a CUDA `attn_mask` will crash.
- For short sequences, `mask_0[:, :, iter - 2]` can index OOB when `step > s + 1`.
- Unnecessary full clone before slicing; and `iter` shadows the Python built-in.

Apply this diff:

```diff
 s = attn_mask.shape[-1]
-for iter in range(2, step + 1):
-    # iter starts from 2nd step
-    zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
-    mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-    mask_0[:, :, iter - 2] = True
-    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-    mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-    for i in range(iter - 1, s - 1):
-        mask_1[:, :, i, i] = False
+max_step = min(int(step), s + 1)
+if max_step != step:
+    warnings.warn(
+        f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+    )
+for step_idx in range(2, max_step + 1):
+    # step_idx starts from 2nd step
+    zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+    mask_0 = attn_mask[:, :, -s:, :].clone()
+    row_idx = step_idx - 2
+    if row_idx < s:
+        mask_0[:, :, row_idx] = True
+    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+    for i in range(step_idx - 1, s - 1):
+        mask_1[:, :, i, i] = False
```

Optional follow-up: vectorize the diagonal update later for perf.
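A quick standalone illustration of why `new_ones` addresses the device/dtype point: it inherits the source tensor's device and dtype, whereas a bare `torch.ones(...)` allocates on CPU with the default float dtype (the shapes below are arbitrary, not taken from the module):

```python
import torch

attn_mask = torch.zeros(2, 1, 4, 4, dtype=torch.bool,
                        device="cuda" if torch.cuda.is_available() else "cpu")

bad = torch.ones(2, 1, 4, 4).bool()    # always CPU; device mismatch when attn_mask is on CUDA
good = attn_mask.new_ones(2, 1, 4, 4)  # same device and dtype (bool) as attn_mask

print(bad.device, bad.dtype)           # cpu torch.bool
print(good.device, good.dtype)         # matches attn_mask, e.g. cuda:0 torch.bool
# torch.cat((attn_mask, bad), dim=-1) would raise on a CUDA attn_mask;
# torch.cat((attn_mask, good), dim=-1) works regardless of device.
```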
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
845-849: Avoid in-place aliasing; use the original mask as RHS. Assigning from `attn_mask` into an overlapping slice is unnecessary and can be error-prone. Read from `attention_mask` (the source you just cloned) for clarity and safety.

```diff
-attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
```
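A small illustration of the safer pattern on a toy tensor (shapes are hypothetical; the variables merely stand in for the cloned `attn_mask` and its source `attention_mask`):

```python
import torch

attention_mask = torch.arange(16.0).reshape(1, 1, 4, 4)
attn_mask = attention_mask.clone()

# Read the shifted window from the untouched source tensor, so the RHS can never
# observe elements the LHS assignment has already overwritten.
attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
print(attn_mask[0, 0])
```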
57f9143 to 66f5c67
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
312-314: Nit: rename `zero_mask` → `full_mask` for clarity. It's an all-True mask; the current name is misleading. Optional rename within this scope.
🔇 Additional comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
324-326: Concatenation pattern LGTM. The two-part concat per iteration keeps shapes consistent and is readable.
309-320: Cap step to base seq length; avoid OOB and rename loop var. For short sequences and large step, `mask_0[:, :, iter - 2, :]` indexes out of bounds. Also, shadowing Python's built-in `iter` is undesirable.

Apply:

```diff
-s = attn_mask.shape[-1]
-for iter in range(2, step + 1):
+s = attn_mask.shape[-1]
+max_step = min(int(step), s + 1)
+if max_step != step:
+    warnings.warn(
+        f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+    )
+for step_idx in range(2, max_step + 1):
     # iter starts from 2nd step
-    zero_mask = attn_mask.new_ones(
+    zero_mask = attn_mask.new_ones(
         attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
     ).bool()
-    mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-    mask_0[:, :, iter - 2, :] = True
-    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+    mask_0 = attn_mask[:, :, -s:, :].clone()
+    mask_0[:, :, step_idx - 2, :] = True
+    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:].clone()
-    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-    for i in range(iter - 1, s - 1):
+    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+    for i in range(step_idx - 1, s - 1):
         mask_1[:, :, i, i] = False
```
Signed-off-by: Ye Yu <[email protected]>
f84d09a to b8cc0de
What does this PR do?
Type of change: refactor
Overview:
Our current set_multi_step_attn_mask function hardcodes the attention mask for step=2,3,4 and does not support an arbitrary step. This PR generalizes it to support any step > 1.
Usage
# Add a code snippet demonstrating how to use this
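The author left the usage snippet empty; below is a hypothetical sketch of how the generalized helper could be called. The function name comes from the review diffs, but the import path, exact signature, mask convention, and output shape are assumptions, not confirmed by this PR.

```python
import torch
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask

s = 8
base_mask = torch.tril(torch.ones(s, s, dtype=torch.bool)).expand(1, 1, s, s).clone()

# Any step > 1 should now work, not only the previously hardcoded 2/3/4.
multi_step_mask = set_multi_step_attention_mask(base_mask, step=6)
print(multi_step_mask.shape)  # expected (1, 1, 6 * s, 6 * s) if the mask grows by s per extra step
```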
Testing
Tested the attention mask locally and passed the sandbox regression test.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
- New Features
- Performance
- Reliability
- Compatibility