[ilu/ttx] optimize int8-KV paged prefill: dequant + reuse bf16 FA2 #361

Open
AbeFei wants to merge 1 commit into master from ilu/optimize_quant_fla_attn

Conversation

@AbeFei commented Jun 12, 2026

Collaborator
  • Add _dequant_paged_kv_block_kernel and rewrite the int8 dequant prefill to dequantize referenced KV blocks to bf16, then reuse the bf16 FA2 kernel; fall back to the scalar kernel for unsupported shapes. ~100x faster (1199ms -> 11.6ms).
  • Pass original (Hkv, D) scales (drop repeat_interleave); cache bf16 dequant buffers.
  • Make the GQA packed kernel's autotune-disabled config shared-mem-safe for large groups to fix OOR in CI; fix double-scaling in the non-packed partial-page loop.
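
As a rough illustration of the first two bullets, here is a pure-Python reference of "dequantize only the referenced blocks", assuming the `(N_blocks, Hkv, block_size, D)` cache layout and `(Hkv, D)` scales mentioned in this PR. The function name `dequant_referenced_blocks` and the nested-list representation are hypothetical; the real work is done by a Triton kernel that is not reproduced here.

```python
import math

def dequant_referenced_blocks(cache_i8, scale, block_table, kv_seq_len, page_size):
    """Dequantize only the physical blocks that one sequence references.

    cache_i8:    nested list [n_blocks][Hkv][page_size][D] of int8 values.
    scale:       nested list [Hkv][D] of per-channel float scales.
    block_table: logical block index -> physical block index for one sequence.
    Returns {physical_block: dequantized [Hkv][page_size][D] floats}.
    """
    n_live = math.ceil(kv_seq_len / page_size)  # blocks that hold real tokens
    out = {}
    for logical in range(n_live):
        phys = block_table[logical]
        out[phys] = [
            [[q * s for q, s in zip(row, scale[h])] for row in head]
            for h, head in enumerate(cache_i8[phys])
        ]
    return out

# Tiny example: 3 physical blocks, Hkv=1, page_size=2, D=2.
cache = [
    [[[1, 2], [3, 4]]],
    [[[5, 6], [7, 8]]],
    [[[9, 10], [11, 12]]],
]
scale = [[0.5, 2.0]]
deq = dequant_referenced_blocks(cache, scale, block_table=[2, 0, 1],
                                kv_seq_len=3, page_size=2)
```

With `kv_seq_len=3` and `page_size=2`, only logical blocks 0 and 1 (physical blocks 2 and 0) are touched. The trailing block is dequantized whole, so in this sketch the attention step is still expected to mask positions past `kv_seq_len`.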

@gemini-code-assist (Bot) left a comment

Contributor

Code Review

This pull request introduces a high-performance FlashAttention-2 prefill path that dequantizes int8 KV blocks into a temporary bf16/fp16 cache before running the native FA2 prefill kernel. It also fixes a scale indexing bug by using kv_head_id instead of q_head_id and improves division safety. However, the reviewer identified a critical issue where the TTXPagedPrefillGQAWithKVDequant operator does not pass the required key_cache_dequant and value_cache_dequant buffers, leading to a guaranteed AssertionError crash. The reviewer suggested dynamically allocating these buffers on the fly if they are not provided to ensure robustness and backward compatibility.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +849 to +856
assert key_cache_dequant is not None, "key_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
assert key_cache_dequant.dtype == q.dtype, (
    "key_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
)
assert value_cache_dequant is not None, "value_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
assert value_cache_dequant.dtype == q.dtype, (
    "value_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
)

Contributor

high

The TTXPagedPrefillGQAWithKVDequant operator does not pass key_cache_dequant and value_cache_dequant buffers to paged_attention_prefill_with_kv_dequant. Under the default configuration where use_dequant_fa2_prefill is True, this will cause a guaranteed AssertionError crash because these buffers are required to be pre-allocated.

To make the API robust and backward-compatible, we should allocate these buffers on the fly if they are not provided, rather than strictly asserting their presence. This allows the 100x faster FA2 path to be used even when the caller does not manage/cache the dequantized buffers.

Suggested change
- assert key_cache_dequant is not None, "key_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
- assert key_cache_dequant.dtype == q.dtype, (
-     "key_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
- )
- assert value_cache_dequant is not None, "value_cache_dequant must be pre-allocated"  # shape: (N_blocks, Hkv, block_size, D)
- assert value_cache_dequant.dtype == q.dtype, (
-     "value_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
- )
+ if key_cache_dequant is None:
+     key_cache_dequant = torch.empty_like(key_cache, dtype=q.dtype)
+ else:
+     assert key_cache_dequant.dtype == q.dtype, (
+         "key_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
+     )
+ if value_cache_dequant is None:
+     value_cache_dequant = torch.empty_like(value_cache, dtype=q.dtype)
+ else:
+     assert value_cache_dequant.dtype == q.dtype, (
+         "value_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
+     )

Collaborator Author

Addressed.

@github-actions

Claude Code Review

Verdict: Request changes -- The new dequant-FA2 prefill path requires pre-allocated buffers, but the public op wrapper does not appear to provide them, and out is now mandatory, which is an API break.

Summary

Adds a dequant-then-bf16-FA2 prefill path for int8 paged KV: a new _dequant_paged_kv_block_kernel materializes referenced int8 blocks into a bf16/fp16 paged cache and then dispatches to the existing bf16 FA2 prefill. Also fixes scale indexing (q_head_id -> kv_head_id) in the scalar dequant kernel and replaces tl.maximum(l, 1e-6) with a zero-safe divide.

Must fix

  • [BLOCKER] out made mandatory without caller update -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:837 -- The previous behavior allocated out when None; switching to an assert out is not None is an API break. Either keep auto-allocation or update all call sites and document the contract.
  • [BLOCKER] Dequant-FA2 path requires buffers the operator wrapper does not pass -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:846-852 -- key_cache_dequant/value_cache_dequant are asserted non-None on the default prefill path (head_dim power-of-two, block_size <= 128 is common), but the diff to operators/attention.py shows no plumbing of these buffers, so calls will hit the assertion. Confirm wrapper path or fall back gracefully when buffers are absent.
  • [BLOCKER] Trailing newline removed -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:1354 -- "No newline at end of file" was introduced; restore it.

Suggestions

  • [MAJOR] Dequant kernel launches batch * MAX_BLOCKS programs unconditionally -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:781 -- For sparsely-used block tables this overlaunches; consider compacting via seqlens_kv to reduce wasted programs and avoid loading garbage physical_block ids past the live range (the early-return guards correctness, not cost).
  • [MAJOR] physical_block for unused slots may be uninitialized -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:741-746 -- The early-return on logical_block * PAGE_SIZE >= kv_seq_len is correct, but ensure block_tables_i32 entries past seqlens_kv are not read elsewhere; the bf16 FA2 kernel also indexes by these tables, so any junk entries must be tolerated there too.
  • [MAJOR] Silent accuracy-driven page-size cutoff -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:16-19,840 -- _DEQUANT_FA2_MAX_PAGE_SIZE = 128 silently routes >128 pages to the slow scalar path; consider logging once or surfacing via config so regressions in deployments using larger pages are visible.
  • [MINOR] tl.where(l>0, ...) evaluates acc / l_safe regardless -- flash_attention.py:222-223,483-484 -- Behavior is correct since l_safe>=1, but a comment noting this avoids future "fix" attempts that re-introduce div-by-zero.
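
To put a number on the over-launch concern in the first suggestion, the following sketch compares a dense `batch * MAX_BLOCKS` grid with the count of blocks that actually hold tokens. The helper name `launched_vs_live` and the sample sequence lengths are made up for illustration; only the grid shape comes from the review.

```python
import math

def launched_vs_live(seqlens_kv, page_size, max_blocks):
    """Return (programs on a dense batch * max_blocks grid, programs with real work)."""
    launched = len(seqlens_kv) * max_blocks
    live = sum(math.ceil(n / page_size) for n in seqlens_kv)
    return launched, live

# Hypothetical batch: one long sequence and two short ones.
launched, live = launched_vs_live([4096, 17, 300], page_size=128, max_blocks=32)
print(launched, live)  # 96 36
```

In this made-up batch, 60 of the 96 programs would take the early return. Whether compaction is worth the extra indexing depends on how skewed real batches are.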

Nits

  • [NIT] int(q_lens.max().item()) -- flash_attention.py:910 -- .item() already returns a Python int for int32 tensors; the cast is redundant.
  • [NIT] Removed blank line before return out -- flash_attention.py:939 -- Minor style churn unrelated to the change.

Notes

  • [CHECK] Confirm callers of paged_attention_prefill_with_kv_dequant_impl (notably any non-TTX path or tests) all pass out= now that it is required.
  • [CHECK] The k/v scale indexing fix (q_head_id -> kv_head_id) changes numerical results for GQA; verify existing tests covered the broken indexing or were passing because callers pre-expanded scales (which the operator change now stops doing).
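
The scale-indexing question can be checked on paper with a small sketch. It assumes the usual contiguous GQA grouping (`kv_head_id = q_head_id // (Hq // Hkv)`), which is what a `repeat_interleave` expansion along the head axis implies; the helper name `kv_head_for` and the sample values are hypothetical.

```python
def kv_head_for(q_head_id, num_q_heads, num_kv_heads):
    """Map a query head to its shared KV head under GQA (contiguous grouping assumed)."""
    group_size = num_q_heads // num_kv_heads
    return q_head_id // group_size

num_q_heads, num_kv_heads = 8, 2
key_scale = [[0.1, 0.2], [0.3, 0.4]]  # compact scales, shape (Hkv, D)

# List equivalent of repeat_interleave along the head axis: shape (Hq, D).
group = num_q_heads // num_kv_heads
expanded = [row for row in key_scale for _ in range(group)]

# Compact scales indexed by kv_head_id match expanded scales indexed by q_head_id.
for q in range(num_q_heads):
    assert key_scale[kv_head_for(q, num_q_heads, num_kv_heads)] == expanded[q]

# Indexing the compact scales by q_head_id instead goes wrong in two ways:
out_of_range = [q for q in range(num_q_heads) if q >= num_kv_heads]
wrong_in_range = [q for q in range(num_kv_heads)
                  if q != kv_head_for(q, num_q_heads, num_kv_heads)]
```

So a kernel that indexes by `q_head_id` is only correct when it is handed the expanded scales, and a kernel that indexes by `kv_head_id` is only correct with the compact ones. Mixing the two conventions yields either out-of-range reads or a silently wrong row.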

@AbeFei force-pushed the ilu/optimize_quant_fla_attn branch from b27d50d to 9682dc2 on June 12, 2026 11:03
@github-actions

Claude Code Review

Verdict: Request changes -- The scalar fallback path now indexes scales by kv_head_id and no longer receives per-Hq expanded scales (good), but the decode path and other callers may still pass Hq-expanded scales; there is also a likely shape bug in the dequant kernel's scale indexing for the fallback.

Summary

Adds a dequant-then-bf16-FA2 prefill path for int8 paged KV by introducing a dedicated dequant kernel that materializes only referenced blocks, plus a numerically safer division in the FA2/decode kernels. Also fixes scale indexing in the existing scalar dequant prefill kernel from q_head_id to kv_head_id and stops expanding scales at the operator layer.

Must fix

  • [BLOCKER] Scalar decode path may now receive wrong-shaped scales -- mojo_opset/backends/ttx/operators/attention.py:110-125 -- The operator no longer expands key_scale/value_scale to Hq, but only the prefill kernel was changed to index by kv_head_id. Verify the decode-with-kv-dequant operator (and any other caller of paged_attention_*_with_kv_dequant) is consistent; if _paged_decode_with_kv_dequant_kernel still indexes by q_head_id, GQA results will be silently wrong.
  • [BLOCKER] Trailing newline missing -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:1370 -- File ends with return out and no newline (\ No newline at end of file); add one.

Suggestions

  • [MAJOR] Dequant buffers always allocated even when unused -- mojo_opset/backends/ttx/operators/attention.py:112-114 -- The operator unconditionally allocates empty_like(key_cache, dtype=query.dtype) for the full paged cache, but the kernel may take the scalar fallback and ignore them. For large caches this is a significant wasted allocation; allocate inside the kernel only when use_dequant_fa2_prefill is true (which it already does when key_cache_dequant is None).
  • [MAJOR] block_tables_i32 may contain stale entries past seqlens_kv -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:736-744 -- The dequant kernel guards by logical_block * PAGE_SIZE >= kv_seq_len, but tl.load(block_tables_ptr + ...) is still issued unconditionally before that check on the next lines; the early-return path is fine, but confirm block_tables_i32 is sized to MAX_BLOCKS for every batch (otherwise the grid B*MAX_BLOCKS over-reads).
  • [MAJOR] BLOCK_P=min(next_pow2(block_size), 128) may exceed PAGE_SIZE bounds when block_size>128 -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:870 -- The use_dequant_fa2_prefill gate ensures block_size <= 128, so BLOCK_P==PAGE_SIZE here and the for p_start in tl.range(0, PAGE_SIZE, BLOCK_P) loop runs exactly once; the masked-tail logic is dead. Either drop the loop or remove the min(..., 128) and the gate, but the current code is misleading.
  • [MINOR] import logging inside function -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:849 -- Move to module top; also a per-call warning on the fallback path will spam logs in steady-state inference. Consider warning_once or a one-shot flag.
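
One way to act on the last suggestion is a module-level warn-once helper built on the standard `logging` module. This is a minimal sketch: the logger name, the `warn_once` helper and the message text are hypothetical, not taken from the PR.

```python
import logging

logger = logging.getLogger("dequant_prefill_example")  # hypothetical logger name
_warned_keys = set()

def warn_once(key, message):
    """Log `message` at WARNING level the first time `key` is seen.

    Returns True if the message was logged, False if it was suppressed.
    """
    if key in _warned_keys:
        return False
    _warned_keys.add(key)
    logger.warning(message)
    return True

# Simulate three prefill calls that hit the same fallback condition.
results = [
    warn_once(("page_size", 256), "page_size=256 > 128: using scalar dequant prefill")
    for _ in range(3)
]
```

Keying on the condition rather than using a single boolean flag lets a different fallback reason (for example another page size) still produce its own first warning.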

Nits

  • [NIT] _DEQUANT_FA2_MAX_PAGE_SIZE = 128 chosen empirically -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:19 -- consider linking the test name in a comment more precisely or filing a tracking issue.
  • [NIT] Reordered dtype= / device= kwargs in torch.empty -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:840 -- pure churn, unrelated to the PR.

Notes

  • [CHECK] Confirm key_cache.is_contiguous() assumptions hold for the dequant kernel's strided pointer arithmetic, especially when callers pass views.
  • [CHECK] The tl.where((l_i > 0)[:, None], acc / l_safe, 0.0) change masks divide-by-zero with zeros instead of 1e-6-clamped output; verify downstream code (e.g., logsumexp returns) does not rely on the previous nonzero-output behavior for fully-masked rows.
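
For readers unfamiliar with the zero-safe divide discussed in the last note, here is a plain-Python analogue of the `tl.where`-based form, written per row. The function name `normalize_rows` and the sample values are illustrative only.

```python
def normalize_rows(acc, l):
    """Divide each accumulator row by its softmax denominator.

    Rows whose denominator is zero (fully masked) are returned as zeros
    instead of being divided by a small clamp value.
    """
    out = []
    for row, l_i in zip(acc, l):
        l_safe = l_i if l_i > 0 else 1.0  # placeholder divisor for masked rows
        out.append([x / l_safe if l_i > 0 else 0.0 for x in row])
    return out

acc = [[1.0, 3.0], [0.0, 0.0]]
l = [2.0, 0.0]  # second row is fully masked
result = normalize_rows(acc, l)
```

The placeholder divisor exists only so that the division is well defined for masked rows; the outer selection then discards that value and writes zeros.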

@AbeFei force-pushed the ilu/optimize_quant_fla_attn branch from 9682dc2 to f0f4181 on June 15, 2026 02:25
@github-actions

Claude Code Review

Verdict: Request changes -- The dequant prefill path always allocates two full bf16 KV-cache buffers even when the kernel chooses the scalar fallback, wasting large amounts of memory on the hot path.

Summary

Adds a dequant-then-bf16-FA2 prefill path for int8 paged KV: a new kernel materializes only referenced int8 blocks into a bf16/fp16 paged cache of identical layout, then reuses the existing bf16 FA2 prefill kernel. Also fixes an indexing bug where the scalar dequant prefill kernel loaded K/V scales by q_head_id instead of kv_head_id, and hardens the divide-by-zero guard in two attention kernels.

Must fix

  • [BLOCKER] Dequant buffers always allocated, even on scalar fallback -- mojo_opset/backends/ttx/operators/attention.py:112-114 -- The operator unconditionally allocates empty_like(key_cache, dtype=query.dtype) for both K and V (typically many GB for a real KV cache), but the kernel only consults them when use_dequant_fa2_prefill is true. Move allocation inside the kernel (current if key_cache_dequant is None branch already handles it) or replicate the feasibility check at the caller.
  • [BLOCKER] Prior scale-indexing bug not covered by indexing change alone -- mojo_opset/backends/ttx/operators/attention.py:118-122 -- The caller now passes (Hkv, D) scales, which is correct for the new dequant kernel and the fixed _paged_prefill_with_kv_dequant_kernel, but any other code path or external caller relying on the prior (Hq, D) expansion will silently read OOB/garbage. Worth asserting key_scale.shape[0] == num_kv_heads at the entry point.
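
The entry-point check proposed in the second blocker could look roughly like the following. It works on plain shape tuples so that it runs without any tensor library; the function name `check_scale_shapes` is hypothetical, and a real version would read `key_scale.shape` directly.

```python
def check_scale_shapes(key_scale_shape, value_scale_shape, num_kv_heads, head_dim):
    """Raise ValueError unless both scale shapes are exactly (num_kv_heads, head_dim)."""
    expected = (num_kv_heads, head_dim)
    for name, shape in (("key_scale", key_scale_shape),
                        ("value_scale", value_scale_shape)):
        if tuple(shape) != expected:
            raise ValueError(
                f"{name} has shape {tuple(shape)}, expected {expected} (Hkv, D)"
            )

# Compact (Hkv, D) scales pass.
check_scale_shapes((2, 128), (2, 128), num_kv_heads=2, head_dim=128)

# Hq-expanded scales from an older caller are rejected instead of being misread.
try:
    check_scale_shapes((16, 128), (2, 128), num_kv_heads=2, head_dim=128)
    rejected = False
except ValueError:
    rejected = True
```

Raising an exception rather than using `assert` keeps the check active when Python runs with `-O`; that is a design choice of this sketch, not something the review prescribes.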

Suggestions

  • [MAJOR] tl.range with BLOCK_P == PAGE_SIZE wastes a loop -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:761 -- When block_size is a power of two and <=128, BLOCK_P == PAGE_SIZE so the for-loop runs exactly once; consider dropping the loop (or asserting) to keep the IR simpler. Also, the min(..., 128) cap is redundant with the module-level _DEQUANT_FA2_MAX_PAGE_SIZE = 128 guard.
  • [MAJOR] logging.warning per call on a hot path -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:847-853 -- This warns on every prefill invocation that hits the fallback, which will spam logs in inference. Either warn-once (module-level flag) or downgrade to debug.
  • [MINOR] import logging inside function -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:846 -- Move to module top to avoid repeated import overhead and match style elsewhere.
  • [MINOR] acc / l_safe with l_safe = where(l>0, l, 1.0) then masking via where is redundant -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:222-224, 484-486 -- The original tl.maximum(l_i, 1e-6) produced O(1e-6)-magnitude noise rows; the new form is correct but you could simplify to acc = tl.where((l_i > 0)[:, None], acc / l_safe[:, None], 0.0) only (current code is fine, just noting l_safe naming since the comment claims >=1 but l_i between 0 and 1 is possible).
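
The loop-count claim in the first suggestion can be verified with a little arithmetic. The sketch below reimplements the next-power-of-two rounding in plain Python and counts the iterations of a `range(0, PAGE_SIZE, BLOCK_P)` loop. The helper names are hypothetical, and which page sizes actually reach this path depends on the gate, which is not shown on this page.

```python
def next_pow2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

def block_p_and_trips(page_size, cap=128):
    """Return (BLOCK_P, loop iterations, whether the last tile needs a mask)."""
    block_p = min(next_pow2(page_size), cap)
    trips = -(-page_size // block_p)  # ceil division
    needs_mask = page_size % block_p != 0
    return block_p, trips, needs_mask

for page_size in (16, 96, 128, 256):
    print(page_size, block_p_and_trips(page_size))
# 16 (16, 1, False)
# 96 (128, 1, True)
# 128 (128, 1, False)
# 256 (128, 2, False)
```

For every page size up to 128 the loop runs once. A non-power-of-two size such as 96 would still need the tail mask if it were admitted, so whether the masking is dead code depends on whether the gate also requires a power-of-two page size.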

Nits

  • [NIT] Comment "l_i_safe >= 1 by construction" is inaccurate -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:222 -- l_safe is >= small positive (sum of exp values), not >= 1. Reword.
  • [NIT] Trailing-whitespace style change dtype=q.dtype, device=q.device -- mojo_opset/backends/ttx/kernels/ilu/flash_attention.py:840 -- Unrelated reorder; keep diff minimal.

Notes

  • [CHECK] _DEQUANT_FA2_MAX_PAGE_SIZE = 128 is justified by a single failing test (M_BF16_ODD_KV_333); confirm whether the drift is from the dequant kernel itself or from FA2 accumulation, since the latter would also affect the non-quant prefill path at large page sizes.
  • [CHECK] _dequant_paged_kv_block_kernel only checks logical_block * PAGE_SIZE >= kv_seq_len to skip; if block_tables contains stale/invalid entries past seqlens_kv, this is fine, but partially-filled trailing blocks still dequant garbage past kv_seq_len % PAGE_SIZE. The downstream FA2 kernel must mask those positions -- verify _paged_prefill_fav2_kernel honors seqlens_kv strictly.
