
Conversation

yeyu-nvidia
Contributor

@yeyu-nvidia yeyu-nvidia commented Sep 19, 2025

What does this PR do?

Type of change: refactor

Overview:
Our current set_multi_step_attn_mask function hardcodes the attention mask for step = 2, 3, 4 and does not support an arbitrary step. This PR generalizes it to support any step > 1.

Usage

# Add a code snippet demonstrating how to use this
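
For illustration, a minimal sketch of a call (not taken from the PR itself; it assumes a working Megatron environment so the plugin imports succeed, uses the Megatron-style boolean convention where True means masked, and picks s and step arbitrarily):

import torch

from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask

s = 8  # base sequence length (arbitrary for this example)
# Causal mask in the Megatron convention: True = masked.
attn_mask = torch.ones(1, 1, s, s, dtype=torch.bool).triu(1)

# Previously limited to step <= 4; now any step > 1 is supported.
multi_step_mask = set_multi_step_attention_mask(attn_mask, step=6)
print(multi_step_mask.shape)  # grows block by block, to (1, 1, 6 * s, 6 * s) here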

Testing

Tested the attention mask locally and passed the sandbox regression test.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Expanded multi-step speculative decoding support beyond four steps for larger draft counts.
  • Performance

    • More efficient iterative attention-mask handling reduces overhead for higher step values.
  • Reliability

    • Unified mask construction improves consistency and reduces edge-case failures.
  • Compatibility

    • Public interfaces remain unchanged; existing integrations continue to work without modification.

@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner September 19, 2025 17:57

coderabbitai bot commented Sep 19, 2025

Walkthrough

Refactors set_multi_step_attention_mask to remove the step <= 4 assertion and replace step-specific branching with a generalized loop that builds per-iteration masks and concatenates them to produce the final attention mask; public signature unchanged.

Changes

Cohort: Multi-step attention mask refactor
File(s): modelopt/torch/speculative/plugins/megatron_eagle.py
Summary of changes: Removed the assert step <= 4 and replaced step-specific branching with a generalized loop, for iter in range(2, step + 1), that constructs per-iteration zero_mask, mask_0, and mask_1 and concatenates them iteratively to build the final attn_mask. The public signature set_multi_step_attention_mask(attn_mask, step) is unchanged.
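
For reference, below is a standalone sketch of the block-growth pattern described above. It is reconstructed from the review diffs later in this thread, not copied from megatron_eagle.py; in particular, the exact concatenation order and the use of new_ones are assumptions based on those diffs.

import torch


def build_multi_step_mask(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    """Grow a (b, h, s, s) boolean mask (True = masked) by one s-sized block per extra step."""
    s = attn_mask.shape[-1]
    for step_idx in range(2, step + 1):
        # All-True block padding the existing rows on the right (named zero_mask in the PR).
        zero_mask = attn_mask.new_ones(
            attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
        )
        # New rows attending to previous tokens: take the last s rows, blank one row,
        # then shift left by one column (cloning the RHS to avoid overlapping writes).
        mask_0 = attn_mask[:, :, -s:, :].clone()
        mask_0[:, :, step_idx - 2, :] = True  # see the review note below on very short sequences
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:].clone()
        # Diagonal block for the newly appended draft tokens.
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
        for i in range(step_idx - 1, s - 1):
            mask_1[:, :, i, i] = False
        # Two-part concat: widen the existing rows, then append the new block row.
        top = torch.cat((attn_mask, zero_mask), dim=-1)
        bottom = torch.cat((mask_0, mask_1), dim=-1)
        attn_mask = torch.cat((top, bottom), dim=-2)
    return attn_mask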

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Eagle as MegatronEagle
  rect rgba(220,235,245,0.5)
    note over Eagle: set_multi_step_attention_mask(attn_mask, step)
    Caller->>Eagle: set_multi_step_attention_mask(attn_mask, step)
    loop iter = 2 .. step
      Eagle->>Eagle: build zero_mask, mask_0, mask_1
      Eagle->>Eagle: concat per-iteration masks into attn_mask
    end
    Eagle-->>Caller: updated attn_mask
  end

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

I looped through steps with nimble feet,
No cap to hops — each mask complete.
I stitch the rows and join the seam,
A rabbit's patchwork, swift and clean. 🥕🐇

Pre-merge checks

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title succinctly captures the PR's primary change: refactoring set_multi_step_attn_mask to support arbitrary step values rather than hardcoded cases. It names the target function and the key intent, is a single concise sentence, and contains no extraneous information. The abbreviation "attn" is clear in this context and does not reduce clarity for a reviewer scanning history.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 66f5c67 and b8cc0de.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)

324-325: LGTM on concat geometry.

The two-stage concat produces the expected (k·s, k·s) growth per iteration; shapes align correctly along both axes.
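
As a quick illustration of that geometry (a standalone shape check, not taken from the PR or its tests; the all-True blocks below merely stand in for zero_mask, mask_0, and mask_1):

import torch

b, h, s = 1, 1, 4
attn_mask = torch.ones(b, h, s, s, dtype=torch.bool).triu(1)

for k in range(1, 4):  # after the k-th growth the mask covers k + 1 steps
    right_pad = attn_mask.new_ones(b, h, attn_mask.shape[2], s)      # plays the role of zero_mask
    bottom_left = attn_mask.new_ones(b, h, s, attn_mask.shape[-1])   # plays the role of mask_0
    bottom_right = attn_mask.new_ones(b, h, s, s)                    # plays the role of mask_1
    top = torch.cat((attn_mask, right_pad), dim=-1)
    bottom = torch.cat((bottom_left, bottom_right), dim=-1)
    attn_mask = torch.cat((top, bottom), dim=-2)
    assert attn_mask.shape[-2:] == ((k + 1) * s, (k + 1) * s)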


310-320: Cap step to avoid OOB; fix overlapping slice copy; minor perf/readability nits.

  • OOB: mask_0[:, :, iter-2, :] indexes past s-1 when step > s+1. Cap the loop bound. This was called out earlier and is still unfixed.
  • Overlapping copy: mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] reads/writes the same view and can corrupt the mask; clone the RHS.
  • Perf/readability: slice before clone (avoid cloning the whole tensor) and rename iter → step_idx.

Apply:

@@
-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
-        # iter starts from 2nd step
+    s = attn_mask.shape[-1]
+    # Bound step to avoid indexing past base sequence length on short sequences.
+    max_step = min(int(step), s + 1)
+    if max_step != step:
+        warnings.warn(
+            f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base sequence length s={s}."
+        )
+    for step_idx in range(2, max_step + 1):
+        # step_idx starts from 2nd step
@@
-        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-        mask_0[:, :, iter - 2, :] = True
-        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_0 = attn_mask[:, :, -s:, :].clone()
+        row_idx = step_idx - 2
+        if row_idx < s:
+            mask_0[:, :, row_idx, :] = True
+        rhs = mask_0[:, :, :, 1:].clone()  # avoid overlapping in-place copy
+        mask_0[:, :, :, :-1] = rhs
@@
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
+        diag_idx = torch.arange(step_idx - 1, s - 1, device=attn_mask.device)
+        if diag_idx.numel() > 0:
+            mask_1[:, :, diag_idx, diag_idx] = False

Run locally (in your PyTorch env) to confirm shapes are correct and no OOB for step > s+1:

#!/bin/bash
python - <<'PY'
import torch
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask

def make_mask(s, device='cpu'):
    # causal mask: True means masked
    return torch.ones(1,1,s,s, dtype=torch.bool, device=device).triu(1)

for s in [1,2,4,8]:
    base = make_mask(s)
    for step in [2, s+1, s+5]:
        out = set_multi_step_attention_mask(base.clone(), step)
        exp = min(int(step), s+1)*s
        assert tuple(out.shape[-2:]) == (exp, exp), (s, step, out.shape)
print("OK")
PY



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

317-318: Optional: vectorize the diagonal write to avoid a Python loop.

Minor perf tidy-up; reduces Python overhead for large s.

-        for i in range(step_idx - 1, s - 1):
-            mask_1[:, :, i, i] = False
+        if step_idx - 1 < s - 1:
+            idx = torch.arange(step_idx - 1, s - 1, device=attn_mask.device)
+            mask_1[:, :, idx, idx] = False
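
A quick equivalence check for that indexing pattern (illustrative only; the values of s and step_idx are arbitrary):

import torch

s, step_idx = 8, 3

loop_mask = torch.ones(1, 1, s, s, dtype=torch.bool)
for i in range(step_idx - 1, s - 1):
    loop_mask[:, :, i, i] = False

vec_mask = torch.ones(1, 1, s, s, dtype=torch.bool)
idx = torch.arange(step_idx - 1, s - 1)
vec_mask[:, :, idx, idx] = False  # advanced indexing writes the same diagonal entries

assert torch.equal(loop_mask, vec_mask)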
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a3fd29 and 8d81046.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

309-327: Generalization approach looks good.

The iterative block-growth builds a (step·s)×(step·s) mask and removes the hardcoded branches. Nice cleanup.

Please re-run the sandbox regression with:

  • CUDA tensors (to confirm device fix), and
  • very short sequences (e.g., s ∈ {1,2,3}) with step ∈ {2,3,4} to ensure no out-of-bounds and correct shapes.


codecov bot commented Sep 19, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.83%. Comparing base (74061f5) to head (b8cc0de).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #348      +/-   ##
==========================================
- Coverage   73.84%   73.83%   -0.01%     
==========================================
  Files         172      172              
  Lines       17453    17453              
==========================================
- Hits        12888    12887       -1     
- Misses       4565     4566       +1     

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

309-327: Fix device/dtype mismatch and short‑sequence OOB; don’t shadow built‑in iter.

  • zero_mask/mask_1 are allocated on CPU; concatenating with CUDA attn_mask will crash.
  • For short sequences, mask_0[:, :, iter-2] can index OOB when step > s+1.
  • Unnecessary full clone before slicing; and iter shadows the Python builtin.

Apply this diff:

 s = attn_mask.shape[-1]
-for iter in range(2, step + 1):
-    # iter starts from 2nd step
-    zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
-    mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-    mask_0[:, :, iter - 2] = True
-    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-    mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-    for i in range(iter - 1, s - 1):
-        mask_1[:, :, i, i] = False
+max_step = min(int(step), s + 1)
+if max_step != step:
+    warnings.warn(
+        f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+    )
+for step_idx in range(2, max_step + 1):
+    # step_idx starts from 2nd step
+    zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+    mask_0 = attn_mask[:, :, -s:, :].clone()
+    row_idx = step_idx - 2
+    if row_idx < s:
+        mask_0[:, :, row_idx] = True
+    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+    for i in range(step_idx - 1, s - 1):
+        mask_1[:, :, i, i] = False

Optional follow-up: vectorize the diagonal update later for perf.
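
For context, Tensor.new_ones allocates on the same device and with the same dtype as the source tensor, which is what makes the suggested fix device-agnostic (illustrative snippet, guarded so it also runs on CPU-only machines):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
attn_mask = torch.ones(1, 1, 4, 4, dtype=torch.bool, device=device).triu(1)

# Inherits attn_mask's device and dtype, so the later torch.cat never mixes devices.
zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], 4)
assert zero_mask.device == attn_mask.device and zero_mask.dtype == torch.bool

torch.cat((attn_mask, zero_mask), dim=-1)  # works on both CPU and CUDA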

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

845-849: Avoid in‑place aliasing; use the original mask as RHS.

Assigning from attn_mask into an overlapping slice is unnecessary and can be error‑prone. Read from attention_mask (the source you just cloned) for clarity and safety.

- attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+ attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
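
A small illustration of the two safe variants (hypothetical standalone tensors; PyTorch does not guarantee well-defined results when the source and destination of an in-place copy share memory):

import torch

attention_mask = torch.ones(1, 1, 4, 4, dtype=torch.bool).triu(1)  # untouched source
attn_mask = attention_mask.clone()

# Safe: the right-hand side reads from the untouched source, not the tensor being written.
attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]

# Equally safe when only one tensor is available: clone the overlapping slice first.
attn_mask2 = attention_mask.clone()
attn_mask2[:, :, :-1, :-1] = attn_mask2[:, :, 1:, 1:].clone()

assert torch.equal(attn_mask, attn_mask2)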
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8d81046 and 9e398c8.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs

@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/refactor_multi_step_attn_mask_for_arbitrary_step branch from 57f9143 to 66f5c67 on September 22, 2025 19:08

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

312-314: Nit: rename zero_mask → full_mask for clarity.

It’s an all-True mask; the current name is misleading. Optional rename within this scope.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 57f9143 and 66f5c67.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)

324-326: Concatenation pattern LGTM.

The two-part concat per iteration keeps shapes consistent and is readable.


309-320: Cap step to base seq length; avoid OOB and rename loop var.

For short sequences and large step, mask_0[:, :, iter - 2, :] indexes out of bounds. Also, shadowing Python’s built‑in iter is undesirable.

Apply:

-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
+    s = attn_mask.shape[-1]
+    max_step = min(int(step), s + 1)
+    if max_step != step:
+        warnings.warn(
+            f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+        )
+    for step_idx in range(2, max_step + 1):
         # iter starts from 2nd step
-        zero_mask = attn_mask.new_ones(
+        zero_mask = attn_mask.new_ones(
             attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
         ).bool()
-        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-        mask_0[:, :, iter - 2, :] = True
-        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_0 = attn_mask[:, :, -s:, :].clone()
+        mask_0[:, :, step_idx - 2, :] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:].clone()
-        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        for i in range(step_idx - 1, s - 1):
             mask_1[:, :, i, i] = False

@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/refactor_multi_step_attn_mask_for_arbitrary_step branch from f84d09a to b8cc0de on September 22, 2025 19:21