[ilu/ttx] optimize int8-KV paged prefill: dequant + reuse bf16 FA2 #361
AbeFei wants to merge 1 commit into
Conversation
AbeFei
commented
Jun 12, 2026
- Add _dequant_paged_kv_block_kernel and rewrite the int8 dequant prefill to dequantize referenced KV blocks to bf16, then reuse the bf16 FA2 kernel; fall back to the scalar kernel for unsupported shapes. ~100x faster (1199ms -> 11.6ms).
- Pass original (Hkv, D) scales (drop repeat_interleave); cache bf16 dequant buffers.
- Make the GQA packed kernel's autotune-disabled config shared-mem-safe for large groups to fix OOR in CI; fix double-scaling in the non-packed partial-page loop.
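The dequantize-then-reuse idea in the description can be modelled in plain Python. This is only a reference sketch, not the PR's Triton kernel: the function name `dequant_referenced_blocks` is hypothetical, symmetric quantization (`value = int8 * scale`, no zero point) is an assumption, and treating `-1` as block-table padding is also an assumption. It shows that only blocks named in the block table are written to the dequantized cache, using the original `(Hkv, D)` scales.

```python
def dequant_referenced_blocks(kv_int8, scales, block_table, out):
    """Dequantize only the blocks that some sequence references.

    kv_int8:     nested lists shaped [n_blocks][Hkv][block_size][D] (int8 values)
    scales:      nested lists shaped [Hkv][D] (one scale per KV head and channel)
    block_table: one list of block ids per sequence; -1 marks padding (assumed)
    out:         float buffer with the same layout as kv_int8, written in place
    """
    # Collect each referenced block once, even if several sequences share it.
    referenced = sorted({b for seq in block_table for b in seq if b >= 0})
    for b in referenced:
        for h, head in enumerate(kv_int8[b]):
            for t, row in enumerate(head):
                # Assumed symmetric dequantization: value = int8 * scale.
                out[b][h][t] = [q * s for q, s in zip(row, scales[h])]
    return referenced


# Tiny hypothetical cache: 4 blocks, 1 KV head, 2 tokens per block, head dim 2.
n_blocks, Hkv, block_size, D = 4, 1, 2, 2
kv_int8 = [[[[b + 1, -(b + 1)] for _ in range(block_size)]
            for _ in range(Hkv)] for b in range(n_blocks)]
scales = [[0.5, 0.25]]
out = [[[[0.0] * D for _ in range(block_size)]
        for _ in range(Hkv)] for _ in range(n_blocks)]

# Two sequences; both reference block 2, and the second is padded with -1.
touched = dequant_referenced_blocks(kv_int8, scales, [[2, 0], [2, -1]], out)
```

Blocks 1 and 3 are never referenced, so their slots in `out` stay untouched; a standard bf16 attention kernel could then read `out` through the same block table.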
Code Review
This pull request introduces a high-performance FlashAttention-2 prefill path that dequantizes int8 KV blocks into a temporary bf16/fp16 cache before running the native FA2 prefill kernel. It also fixes a scale indexing bug by using kv_head_id instead of q_head_id and improves division safety. However, the reviewer identified a critical issue where the TTXPagedPrefillGQAWithKVDequant operator does not pass the required key_cache_dequant and value_cache_dequant buffers, leading to a guaranteed AssertionError crash. The reviewer suggested dynamically allocating these buffers on the fly if they are not provided to ensure robustness and backward compatibility.
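The scale-indexing fix mentioned in the summary can be illustrated with a small sketch. It assumes the usual GQA layout in which consecutive query heads share one KV head (the layout that `repeat_interleave` would produce); the helper names `kv_head_for` and `scale_row` are hypothetical and do not come from the PR.

```python
def kv_head_for(q_head_id, num_q_heads, num_kv_heads):
    """Map a query head to the KV head it shares under grouped-query attention."""
    assert num_q_heads % num_kv_heads == 0, "Hq must be a multiple of Hkv"
    group_size = num_q_heads // num_kv_heads
    return q_head_id // group_size


def scale_row(scales, q_head_id, num_q_heads):
    """Look up per-channel scales stored as (Hkv, D), given a query head id."""
    return scales[kv_head_for(q_head_id, num_q_heads, len(scales))]


# Hypothetical configuration: 8 query heads sharing 2 KV heads, head dim 2.
scales = [[0.5, 0.25], [2.0, 4.0]]   # shape (Hkv=2, D=2)
row = scale_row(scales, q_head_id=5, num_q_heads=8)
```

With scales stored as `(Hkv, D)`, indexing them by `q_head_id` directly would select the wrong row or read past the end of the tensor, which is why the kernel has to index by `kv_head_id` once the `repeat_interleave` expansion is dropped.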
    assert key_cache_dequant is not None, "key_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
    assert key_cache_dequant.dtype == q.dtype, (
        "key_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
    )
    assert value_cache_dequant is not None, "value_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
    assert value_cache_dequant.dtype == q.dtype, (
        "value_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
    )
The TTXPagedPrefillGQAWithKVDequant operator does not pass key_cache_dequant and value_cache_dequant buffers to paged_attention_prefill_with_kv_dequant. Under the default configuration where use_dequant_fa2_prefill is True, this will cause a guaranteed AssertionError crash because these buffers are required to be pre-allocated.
To make the API robust and backward-compatible, we should allocate these buffers on the fly if they are not provided, rather than strictly asserting their presence. This allows the 100x faster FA2 path to be used even when the caller does not manage/cache the dequantized buffers.
Suggested change:

    - assert key_cache_dequant is not None, "key_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
    - assert key_cache_dequant.dtype == q.dtype, (
    -     "key_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
    - )
    - assert value_cache_dequant is not None, "value_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
    - assert value_cache_dequant.dtype == q.dtype, (
    -     "value_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
    - )
    + if key_cache_dequant is None:
    +     key_cache_dequant = torch.empty_like(key_cache, dtype=q.dtype)
    + else:
    +     assert key_cache_dequant.dtype == q.dtype, (
    +         "key_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
    +     )
    + if value_cache_dequant is None:
    +     value_cache_dequant = torch.empty_like(value_cache, dtype=q.dtype)
    + else:
    +     assert value_cache_dequant.dtype == q.dtype, (
    +         "value_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
    +     )
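The cost of allocating the dequantized buffers on the fly can be estimated from the buffer shape `(N_blocks, Hkv, block_size, D)` given in the assert comments. The sketch below is plain arithmetic; the function name `dequant_buffer_bytes` and the example sizes are hypothetical and not taken from the PR or its benchmarks.

```python
def dequant_buffer_bytes(n_blocks, num_kv_heads, block_size, head_dim,
                         bytes_per_elem=2):
    """Bytes needed for the K and V dequant buffers together.

    Each buffer has shape (n_blocks, num_kv_heads, block_size, head_dim);
    bf16 and fp16 both take 2 bytes per element.
    """
    per_buffer = n_blocks * num_kv_heads * block_size * head_dim * bytes_per_elem
    return 2 * per_buffer  # one buffer for keys, one for values


# Hypothetical cache: 8192 blocks, 8 KV heads, 16 tokens per block, head dim 128.
total = dequant_buffer_bytes(8192, 8, 16, 128)
total_mib = total / 1024**2
```

For these assumed sizes the two buffers together take 512 MiB, twice the footprint of the int8 cache they mirror, so whether they are cached across calls or allocated per call is a memory and latency trade-off.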
Claude Code Review
Verdict: Request changes -- New dequant-FA2 prefill path requires pre-allocated buffers but the public op wrapper does not appear to provide them, and ...
Summary: Adds a dequant-then-bf16-FA2 prefill path for int8 paged KV: a new ...
Must fix
Suggestions (4)
Nits (2)
Notes
b27d50d to 9682dc2
Claude Code Review
Verdict: Request changes -- The scalar fallback path now uses kv_head_id-indexed scales and no longer receives per-Hq expanded scales (good), but the decode path and other callers may still pass Hq-expanded scales; there is also a likely shape bug in the dequant kernel's scale indexing for the fallback.
Summary: Adds a dequant-then-bf16-FA2 prefill path for int8 paged KV by introducing a dedicated dequant kernel that materializes only referenced blocks, plus a numerically safer division in the FA2/decode kernels. Also fixes scale indexing in the existing scalar dequant prefill kernel from ...
Must fix
Suggestions (4)
Nits (2)
Notes
9682dc2 to f0f4181
Claude Code Review
Verdict: Request changes -- The dequant prefill path always allocates two full bf16 KV-cache buffers even when the kernel chooses the scalar fallback, wasting large amounts of memory on the hot path.
Summary: Adds a dequant-then-bf16-FA2 prefill path for int8 paged KV: a new kernel materializes only referenced int8 blocks into a bf16/fp16 paged cache of identical layout, then reuses the existing bf16 FA2 prefill kernel. Also fixes an indexing bug where the scalar dequant prefill kernel loaded K/V scales by ...
Must fix
Suggestions (4)
Nits (2)
Notes