swa_paged_prefill opt: merge mask #376
Code Review
This pull request refactors the sliding window attention (SWA) paged prefill kernel by replacing the complex inline mask generation logic with a single pre-computed causal mask with windowing. This simplifies the Triton kernels and removes the skip_window_mask parameter. Feedback is provided regarding a potential TypeError in get_mask_causal_with_window when handling None values for window sizes, as well as a performance bottleneck caused by creating the mask on the CPU and copying it to the NPU on every forward pass. It is recommended to create the mask directly on the target device and cache the generated masks to avoid redundant allocations.
def get_mask_causal_with_window(
    BLOCK_M: int,
    BLOCK_N: int,
    local_window_size: Optional[int] = None,
    global_window_size: Optional[int] = None,
):
    M = global_window_size + local_window_size + 4 * max(BLOCK_M, BLOCK_N)
    N = global_window_size + local_window_size + 5 * max(BLOCK_M, BLOCK_N)

    causal = torch.ones(M, N, dtype=torch.bool).tril()

    sink_band = torch.zeros(M, N, dtype=torch.bool)
    sink_band[:, :global_window_size] = True

    local_band = torch.ones(M, N, dtype=torch.bool).triu(diagonal=-local_window_size)

    return causal & (sink_band | local_band)
Performance & Correctness Issue: CPU-to-NPU overhead and potential TypeError
- Potential TypeError: The type hints for local_window_size and global_window_size are Optional[int]. If either is None, computing M and N will raise a TypeError (e.g., TypeError: unsupported operand type(s) for +: 'int' and 'NoneType').
- Performance Bottleneck: This function is called on every single forward pass of the attention kernel. Creating a large mask tensor (e.g., 5000x5000 elements) on the CPU and then copying it to the NPU via .npu() introduces a massive CPU-to-NPU transfer overhead and memory allocation churn, which will severely degrade performance.
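The first point above can be reproduced without torch. The following is a minimal sketch: `mask_rows` and `try_mask_rows` are hypothetical helpers introduced only for illustration, and `mask_rows` copies just the arithmetic used for M in the reviewed function.

```python
from typing import Optional


def mask_rows(
    block_m: int,
    block_n: int,
    local_window_size: Optional[int] = None,
    global_window_size: Optional[int] = None,
) -> int:
    # Hypothetical helper: same expression as the row count M in the
    # reviewed function, with no guard against None.
    return global_window_size + local_window_size + 4 * max(block_m, block_n)


def try_mask_rows(**kwargs) -> str:
    # Report either the computed size or the TypeError raised by None inputs.
    try:
        return str(mask_rows(64, 64, **kwargs))
    except TypeError as exc:
        return f"TypeError: {exc}"


print(try_mask_rows(local_window_size=128, global_window_size=4))
print(try_mask_rows(global_window_size=4))  # local_window_size left as None
```

With both sizes given, the call returns a size; leaving either argument at its default of None raises the TypeError before any tensor is allocated.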
Solution
- Handle None values gracefully (when local_window_size is None, there is no local window constraint, so the mask is simply a causal mask).
- Create the tensors directly on the target device (device="npu").
- Cache the generated masks using a global dictionary cache to avoid redundant allocations and transfers on subsequent forward passes.
from typing import Optional

import torch

# Module-level cache keyed by every argument that affects the mask.
_causal_mask_cache = {}

def get_mask_causal_with_window(
    BLOCK_M: int,
    BLOCK_N: int,
    local_window_size: Optional[int] = None,
    global_window_size: Optional[int] = None,
    device: str = "npu",
):
    key = (BLOCK_M, BLOCK_N, local_window_size, global_window_size, device)
    if key in _causal_mask_cache:
        return _causal_mask_cache[key]

    # Treat a missing global (sink) window as width 0.
    g_vol = global_window_size if global_window_size is not None else 0

    # No local window constraint: the mask is a plain causal mask.
    if local_window_size is None:
        M = g_vol + 4 * max(BLOCK_M, BLOCK_N)
        N = g_vol + 5 * max(BLOCK_M, BLOCK_N)
        mask = torch.ones(M, N, dtype=torch.bool, device=device).tril()
        _causal_mask_cache[key] = mask
        return mask

    M = g_vol + local_window_size + 4 * max(BLOCK_M, BLOCK_N)
    N = g_vol + local_window_size + 5 * max(BLOCK_M, BLOCK_N)

    causal = torch.ones(M, N, dtype=torch.bool, device=device).tril()

    sink_band = torch.zeros(M, N, dtype=torch.bool, device=device)
    sink_band[:, :g_vol] = True

    local_band = torch.ones(M, N, dtype=torch.bool, device=device).triu(diagonal=-local_window_size)

    mask = causal & (sink_band | local_band)
    _causal_mask_cache[key] = mask
    return mask
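As a way to sanity-check the mask semantics, the sketch below restates the same rule element by element in plain Python. `reference_mask` is a hypothetical helper, not part of the PR. It assumes the conventions of the torch calls above: `tril()` keeps entries with j <= i, `triu(diagonal=-L)` keeps entries with j >= i - L, and the sink band covers columns j < global_window_size.

```python
from typing import List, Optional


def reference_mask(
    rows: int,
    cols: int,
    local_window_size: Optional[int] = None,
    global_window_size: Optional[int] = None,
) -> List[List[bool]]:
    # Hypothetical pure-Python reference for causal & (sink_band | local_band).
    sink_width = global_window_size if global_window_size is not None else 0
    mask = []
    for i in range(rows):
        row = []
        for j in range(cols):
            causal = j <= i                      # tril()
            in_sink = j < sink_width             # sink_band[:, :g] = True
            # A local window of None means no local constraint at all.
            in_local = local_window_size is None or j >= i - local_window_size
            row.append(causal and (in_sink or in_local))
        mask.append(row)
    return mask


for row in reference_mask(4, 5, local_window_size=1, global_window_size=1):
    print("".join("1" if keep else "0" for keep in row))
```

Under these assumptions, the nested list could be compared against `mask[:rows, :cols].tolist()` for a small configuration. Note that with this diagonal convention each query row sees local_window_size + 1 keys (itself plus local_window_size predecessors) in addition to the sink columns.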