swa_paged_prefill opt: merge mask #376

Open
Todobe wants to merge 3 commits into XPU-Forces:hw/950-perf from Todobe:hw/prefill_swa_opt

Conversation

@Todobe Todobe commented Jun 26, 2026

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the sliding window attention (SWA) paged prefill kernel by replacing the complex inline mask generation logic with a single pre-computed causal mask with windowing. This simplifies the Triton kernels and removes the skip_window_mask parameter. Feedback is provided regarding a potential TypeError in get_mask_causal_with_window when handling None values for window sizes, as well as a performance bottleneck caused by creating the mask on the CPU and copying it to the NPU on every forward pass. It is recommended to create the mask directly on the target device and cache the generated masks to avoid redundant allocations.

Comment on lines +42 to +58
def get_mask_causal_with_window(
    BLOCK_M: int,
    BLOCK_N: int,
    local_window_size: Optional[int] = None,
    global_window_size: Optional[int] = None,
):
    M = global_window_size + local_window_size + 4 * max(BLOCK_M, BLOCK_N)
    N = global_window_size + local_window_size + 5 * max(BLOCK_M, BLOCK_N)

    causal = torch.ones(M, N, dtype=torch.bool).tril()

    sink_band = torch.zeros(M, N, dtype=torch.bool)
    sink_band[:, :global_window_size] = True

    local_band = torch.ones(M, N, dtype=torch.bool).triu(diagonal=-local_window_size)

    return causal & (sink_band | local_band)
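
To make the combined mask concrete, here is a minimal pure-Python sketch of the predicate that the three tensor operations above appear to compute. Position (i, j) is kept when j <= i (the `tril`), and either j < global_window_size (the sink band) or j >= i - local_window_size (the `triu(diagonal=-local_window_size)` band). The helper names `reference_mask` and `render` and the tiny sizes are illustrative assumptions, not part of the PR.

```python
def reference_mask(rows, cols, local_window_size, global_window_size):
    """Build the causal-with-window mask as nested lists of bools."""
    mask = []
    for i in range(rows):
        row = []
        for j in range(cols):
            causal = j <= i                      # mirrors torch.ones(...).tril()
            sink = j < global_window_size        # mirrors sink_band[:, :global] = True
            local = j >= i - local_window_size   # mirrors triu(diagonal=-local)
            row.append(causal and (sink or local))
        mask.append(row)
    return mask


def render(mask):
    """Render each row as a string of 1s (visible) and 0s (masked out)."""
    return ["".join("1" if v else "0" for v in row) for row in mask]


if __name__ == "__main__":
    for line in render(reference_mask(6, 6, local_window_size=1, global_window_size=1)):
        print(line)
    # 100000
    # 110000
    # 111000
    # 101100
    # 100110
    # 100011
```

Under this reading, a `local_window_size` of 1 lets each query see itself plus one preceding position, and column 0 stays visible to every row because it falls inside the sink band.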

high

Performance & Correctness Issue: CPU-to-NPU overhead and potential TypeError

  1. Potential TypeError: The type hints for local_window_size and global_window_size are Optional[int]. If either is None, computing M and N will raise a TypeError (e.g., TypeError: unsupported operand type(s) for +: 'int' and 'NoneType').
  2. Performance Bottleneck: This function is called on every forward pass of the attention kernel. Building a large boolean mask on the CPU (e.g., 5000x5000 elements) and then copying it to the NPU via .npu() adds a host-to-device transfer and a fresh allocation to every call, which is likely to degrade performance noticeably.
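
The first point can be reproduced without torch. The sketch below copies only the size arithmetic for M into a hypothetical helper, `mask_rows` (a name introduced here for illustration), and calls it with the default arguments.

```python
from typing import Optional


def mask_rows(block_m: int, block_n: int,
              local_window_size: Optional[int] = None,
              global_window_size: Optional[int] = None) -> int:
    # Same arithmetic as the M computation in the reviewed function.
    return global_window_size + local_window_size + 4 * max(block_m, block_n)


if __name__ == "__main__":
    try:
        mask_rows(128, 128)  # both window sizes fall back to None
    except TypeError as exc:
        print("TypeError:", exc)
```

The failure only disappears when both window sizes are passed as integers, so the `Optional[int] = None` defaults advertise a call pattern the function body cannot serve.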

Solution

  • Handle None values gracefully (when local_window_size is None, there is no local window constraint, so the mask is simply a causal mask).
  • Create the tensors directly on the target device (device="npu").
  • Cache the generated masks using a global dictionary cache to avoid redundant allocations and transfers on subsequent forward passes.
_causal_mask_cache = {}


def get_mask_causal_with_window(
        BLOCK_M: int,
        BLOCK_N: int,
        local_window_size: Optional[int] = None,
        global_window_size: Optional[int] = None,
        device: str = "npu",
):
    key = (BLOCK_M, BLOCK_N, local_window_size, global_window_size, device)
    if key in _causal_mask_cache:
        return _causal_mask_cache[key]

    g_vol = global_window_size if global_window_size is not None else 0
    if local_window_size is None:
        M = g_vol + 4 * max(BLOCK_M, BLOCK_N)
        N = g_vol + 5 * max(BLOCK_M, BLOCK_N)
        mask = torch.ones(M, N, dtype=torch.bool, device=device).tril()
        _causal_mask_cache[key] = mask
        return mask

    M = g_vol + local_window_size + 4 * max(BLOCK_M, BLOCK_N)
    N = g_vol + local_window_size + 5 * max(BLOCK_M, BLOCK_N)

    causal = torch.ones(M, N, dtype=torch.bool, device=device).tril()

    sink_band = torch.zeros(M, N, dtype=torch.bool, device=device)
    sink_band[:, :g_vol] = True

    local_band = torch.ones(M, N, dtype=torch.bool, device=device).triu(diagonal=-local_window_size)

    mask = causal & (sink_band | local_band)
    _causal_mask_cache[key] = mask
    return mask
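
As an alternative to a hand-rolled global dictionary, the same memoization could be sketched with `functools.lru_cache`. The version below is a pure-Python stand-in that returns nested tuples instead of a tensor so it runs without torch or an NPU; the name `get_mask_cached` is an assumption made for illustration, and the `None` handling follows the suggestion above.

```python
from functools import lru_cache
from typing import Optional, Tuple

Mask = Tuple[Tuple[bool, ...], ...]


@lru_cache(maxsize=None)
def get_mask_cached(block_m: int, block_n: int,
                    local_window_size: Optional[int] = None,
                    global_window_size: Optional[int] = None) -> Mask:
    g = 0 if global_window_size is None else global_window_size
    w = local_window_size
    pad = max(block_m, block_n)
    rows = g + (0 if w is None else w) + 4 * pad
    cols = g + (0 if w is None else w) + 5 * pad
    # With no local window the mask is purely causal; otherwise a position is
    # visible when it is causal and lies in the sink band or the local band.
    return tuple(
        tuple(j <= i and (w is None or j < g or j >= i - w) for j in range(cols))
        for i in range(rows)
    )


if __name__ == "__main__":
    first = get_mask_cached(2, 2, 1, 1)
    second = get_mask_cached(2, 2, 1, 1)
    print(first is second)               # the second call reuses the cached object
    print(get_mask_cached.cache_info())  # hit and miss counters
```

Two caveats apply to either caching approach. The cached object is shared between callers, so with real tensors it should not be modified in place. `lru_cache` also requires hashable arguments and may treat positional and keyword call patterns as separate entries.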
