Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 186 additions & 12 deletions mojo_opset/backends/ttx/kernels/ilu/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

from .utils import LOG2E, libentry, smart_triton_autotune

# Max page size for int8 dequant + bf16 FA2 prefill. Larger pages (e.g. 256) fall
# back to the scalar kernel because bf16 FA2 + int8 dequant drifts near the accuracy
# threshold (see M_BF16_ODD_KV_333 in test_attention_quant.py).
_DEQUANT_FA2_MAX_PAGE_SIZE = 128


@triton.jit
def inline_causal_mask(q_blk_start, kv_blk_start, kv_cache_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
Expand Down Expand Up @@ -214,8 +219,9 @@ def _paged_prefill_fav2_kernel(
K_T_blk_ptr = tl.advance(K_T_blk_ptr, (0, BLOCK_N))
V_blk_ptr = tl.advance(V_blk_ptr, (BLOCK_N, 0))

l_i = tl.maximum(l_i, 1e-6)
acc = acc / l_i[:, None]
# l_i_safe >= 1 by construction, so acc / l_i_safe is always well-defined.
l_i_safe = tl.where(l_i > 0, l_i, 1.0)
acc = tl.where((l_i > 0)[:, None], acc / l_i_safe[:, None], 0.0)

O_blk_ptr = tl.make_block_ptr(
base=Out + (q_start + q_blk_start) * stride_ot + q_head_id * stride_oh,
Expand Down Expand Up @@ -475,8 +481,9 @@ def _paged_decode_gqa_kernel(
acc = acc * alpha + tl.sum(p[:, None] * v.to(tl.float32), axis=0)
m = m_new

l = tl.maximum(l, 1e-6)
out = acc / l
# l_safe >= 1 by construction, so acc / l_safe is always well-defined.
l_safe = tl.where(l > 0, l, 1.0)
out = tl.where(l > 0, acc / l_safe, 0.0)
tl.store(out_ptrs, out.to(OUT_T), mask=d_mask)


Expand Down Expand Up @@ -541,11 +548,11 @@ def _paged_prefill_with_kv_dequant_kernel(
).to(tl.float32)

k_scale_vec = tl.load(
K_qscale + q_head_id * stride_ks_h + offs_d * stride_ks_d,
K_qscale + kv_head_id * stride_ks_h + offs_d * stride_ks_d,
mask=d_mask, other=0.0,
)
v_scale_vec = tl.load(
V_qscale + q_head_id * stride_vs_h + offs_d * stride_vs_d,
V_qscale + kv_head_id * stride_vs_h + offs_d * stride_vs_d,
mask=d_mask, other=0.0,
)

Expand Down Expand Up @@ -691,6 +698,91 @@ def grid(META):
return out


@triton.jit
def _dequant_paged_kv_block_kernel(
K_in,
V_in,
K_out,
V_out,
K_scale,
V_scale,
seqlens_kv_ptr,
block_tables_ptr,
stride_kb, stride_kh, stride_kn, stride_kd,
stride_vb, stride_vh, stride_vn, stride_vd,
stride_kob, stride_koh, stride_kon, stride_kod,
stride_vob, stride_voh, stride_von, stride_vod,
stride_ks_h, stride_ks_d,
stride_vs_h, stride_vs_d,
stride_bt_batch, stride_bt_block,
MAX_BLOCKS: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_D: tl.constexpr,
PAGE_SIZE: tl.constexpr,
BLOCK_P: tl.constexpr,
):
"""Dequantize the int8 paged K/V blocks referenced by ``block_tables`` into
a bf16/fp16 paged cache of identical layout, applying per-channel scales.

Pure load/scale/store; no ``tl.dot`` so the ILU Triton int8->dot compiler
bug is avoided. Only physical blocks within ``seqlens_kv`` are touched, so a
large preallocated KV cache is not fully materialized.

Scales are promoted to fp32 before the multiply; fp32 scales are recommended
for the dequant-FA2 prefill path, though bf16/fp16 scales are accepted.
"""
pid = tl.program_id(0)
kv_head_id = tl.program_id(1)

b_id = pid // MAX_BLOCKS
logical_block = pid % MAX_BLOCKS

kv_seq_len = tl.load(seqlens_kv_ptr + b_id).to(tl.int32)
if logical_block * PAGE_SIZE >= kv_seq_len:
return

physical_block = tl.load(
block_tables_ptr + b_id * stride_bt_batch + logical_block * stride_bt_block
).to(tl.int32)

offs_d = tl.arange(0, BLOCK_D)
d_mask = offs_d < HEAD_DIM

k_scale = tl.load(
K_scale + kv_head_id * stride_ks_h + offs_d * stride_ks_d, mask=d_mask, other=0.0
).to(tl.float32)
v_scale = tl.load(
V_scale + kv_head_id * stride_vs_h + offs_d * stride_vs_d, mask=d_mask, other=0.0
).to(tl.float32)

for p_start in tl.range(0, PAGE_SIZE, BLOCK_P):
offs_p = p_start + tl.arange(0, BLOCK_P)
p_mask = offs_p < PAGE_SIZE
mask = p_mask[:, None] & d_mask[None, :]

k_in_ptrs = (
K_in + physical_block * stride_kb + kv_head_id * stride_kh
+ offs_p[:, None] * stride_kn + offs_d[None, :] * stride_kd
)
v_in_ptrs = (
V_in + physical_block * stride_vb + kv_head_id * stride_vh
+ offs_p[:, None] * stride_vn + offs_d[None, :] * stride_vd
)
k = tl.load(k_in_ptrs, mask=mask, other=0).to(tl.float32) * k_scale[None, :]
v = tl.load(v_in_ptrs, mask=mask, other=0).to(tl.float32) * v_scale[None, :]

k_out_ptrs = (
K_out + physical_block * stride_kob + kv_head_id * stride_koh
+ offs_p[:, None] * stride_kon + offs_d[None, :] * stride_kod
)
v_out_ptrs = (
V_out + physical_block * stride_vob + kv_head_id * stride_voh
+ offs_p[:, None] * stride_von + offs_d[None, :] * stride_vod
)
tl.store(k_out_ptrs, k.to(K_out.dtype.element_ty), mask=mask)
tl.store(v_out_ptrs, v.to(V_out.dtype.element_ty), mask=mask)


def paged_attention_prefill_with_kv_dequant_impl(
q: torch.Tensor,
key_cache: torch.Tensor,
Expand All @@ -706,15 +798,22 @@ def paged_attention_prefill_with_kv_dequant_impl(
max_seqlen_k: Optional[int] = None,
out: Optional[torch.Tensor] = None,
block_tables_i32: Optional[torch.Tensor] = None,
key_cache_dequant: Optional[torch.Tensor] = None,
value_cache_dequant: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Paged prefill attention with int8 KV cache and per-channel scales.

The int8 KV blocks referenced by ``block_tables`` are first dequantized into
a bf16/fp16 paged cache of identical layout, then the regular bf16 FA2 paged
prefill kernel runs on it (same path as bf16 ``paged_attention_prefill_impl``).

Args:
q: (T, Hq, D) bf16/fp16 query tokens.
key_cache: (N_blocks, Hkv, block_size, D) int8 key cache.
k_qscale: (Hkv, D) float32 per-channel key scale.
k_qscale: (Hkv, D) per-channel key scale indexed by ``kv_head_id`` in the
dequant kernel; fp32 recommended for the dequant-FA2 prefill path.
value_cache: (N_blocks, Hkv, block_size, D) int8 value cache.
v_qscale: (Hkv, D) float32 per-channel value scale.
v_qscale: (Hkv, D) per-channel value scale (same indexing as ``k_qscale``).
cu_seqlens_q: (B+1,) int32.
seqlens_kv: (B,) int32 or None.
block_tables: (B, num_blocks) int32.
Expand All @@ -724,10 +823,11 @@ def paged_attention_prefill_with_kv_dequant_impl(
max_seqlen_k: max KV length hint.
out: pre-allocated output tensor (T, Hq, D).
block_tables_i32: pre-converted int32 block_tables.
key_cache_dequant: optional reusable bf16/fp16 K cache buffer.
value_cache_dequant: optional reusable bf16/fp16 V cache buffer.
"""
total_q_tokens, num_q_heads, head_dim = q.shape
_, num_kv_heads, block_size, _ = key_cache.shape
batch_size = cu_seqlens_q.shape[0] - 1

sm_scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)

Expand All @@ -737,20 +837,94 @@ def paged_attention_prefill_with_kv_dequant_impl(
seqlens_kv = seqlens_kv.to(torch.int32)

if out is None:
out = torch.empty(total_q_tokens, num_q_heads, head_dim, device=q.device, dtype=q.dtype)
out = torch.empty(total_q_tokens, num_q_heads, head_dim, dtype=q.dtype, device=q.device)
if block_tables_i32 is None:
block_tables_i32 = block_tables.to(torch.int32)

BLOCK_D = triton.next_power_of_2(head_dim)

# Dequant-FA2 prefill path: dequant int8 KV to bf16, then invoke the native
# bf16 FA2 prefill kernel (per-q-head). Feasible when head_dim is power-of-two
# and block_size <= _DEQUANT_FA2_MAX_PAGE_SIZE (see module constant).
use_dequant_fa2_prefill = (head_dim == BLOCK_D) and (block_size <= _DEQUANT_FA2_MAX_PAGE_SIZE)
if not use_dequant_fa2_prefill:
import logging
logging.getLogger(__name__).warning(
"paged_attention_prefill_with_kv_dequant: falling back to scalar path "
"(head_dim=%d, block_size=%d, max_block_for_fa2=%d). "
"Performance may be significantly degraded.",
head_dim, block_size, _DEQUANT_FA2_MAX_PAGE_SIZE,
)

if use_dequant_fa2_prefill:
if key_cache_dequant is None:
key_cache_dequant = torch.empty_like(key_cache, dtype=q.dtype)
else:
assert key_cache_dequant.dtype == q.dtype, (
"key_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
)
if value_cache_dequant is None:
value_cache_dequant = torch.empty_like(value_cache, dtype=q.dtype)
else:
assert value_cache_dequant.dtype == q.dtype, (
"value_cache_dequant dtype must match query dtype for the dequant-FA2 prefill path"
)

# ILU Triton needs power-of-2 vector tiles; the page mask covers any tail.
BLOCK_P = min(triton.next_power_of_2(block_size), 128)
max_blocks = block_tables_i32.shape[1]
if max_blocks == 0:
return out

dequant_grid = (block_tables_i32.shape[0] * max_blocks, num_kv_heads)
_dequant_paged_kv_block_kernel[dequant_grid](
key_cache,
value_cache,
key_cache_dequant,
value_cache_dequant,
k_qscale,
v_qscale,
seqlens_kv,
block_tables_i32,
key_cache.stride(0), key_cache.stride(1), key_cache.stride(2), key_cache.stride(3),
value_cache.stride(0), value_cache.stride(1), value_cache.stride(2), value_cache.stride(3),
key_cache_dequant.stride(0), key_cache_dequant.stride(1),
key_cache_dequant.stride(2), key_cache_dequant.stride(3),
value_cache_dequant.stride(0), value_cache_dequant.stride(1),
value_cache_dequant.stride(2), value_cache_dequant.stride(3),
k_qscale.stride(0), k_qscale.stride(1),
v_qscale.stride(0), v_qscale.stride(1),
block_tables_i32.stride(0), block_tables_i32.stride(1),
max_blocks,
HEAD_DIM=head_dim,
BLOCK_D=BLOCK_D,
PAGE_SIZE=block_size,
BLOCK_P=BLOCK_P,
)

return paged_attention_prefill_impl(
q,
key_cache_dequant,
value_cache_dequant,
cu_seqlens_q,
seqlens_kv,
block_tables_i32,
gqa_interleave,
sm_scale,
max_q_len=max_seqlen_q,
max_total_seq_len=max_seqlen_k,
out=out,
block_tables_i32=block_tables_i32,
)

# Fallback: scalar per-token dequant attention (handles any head_dim / page).
if max_seqlen_q is not None:
max_q_len = max_seqlen_q
else:
q_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_q_len = q_lens.max().item()

grid = (max_q_len, num_q_heads, batch_size)

grid = (max_q_len, num_q_heads, cu_seqlens_q.shape[0] - 1)
_paged_prefill_with_kv_dequant_kernel[grid](
q,
key_cache,
Expand Down
25 changes: 10 additions & 15 deletions mojo_opset/backends/ttx/operators/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,33 +107,28 @@ def forward(
else cu_total_seq_lens[1:] - cu_total_seq_lens[:-1]
)

num_q_heads = query.shape[1]
num_kv_heads = key_cache.shape[1]

if num_q_heads != num_kv_heads:
if self.gqa_layout == "AABB":
k_qscale_expanded = key_scale.repeat_interleave(num_q_heads // num_kv_heads, dim=0)
v_qscale_expanded = value_scale.repeat_interleave(num_q_heads // num_kv_heads, dim=0)
else:
k_qscale_expanded = key_scale.repeat((num_q_heads // num_kv_heads, 1))
v_qscale_expanded = value_scale.repeat((num_q_heads // num_kv_heads, 1))
else:
k_qscale_expanded = key_scale
v_qscale_expanded = value_scale
# Pre-allocate dequant buffers for all paths (the kernel internally
# decides whether to use them based on head_dim / block_size).
key_cache_dequant = torch.empty_like(key_cache, dtype=query.dtype)
value_cache_dequant = torch.empty_like(value_cache, dtype=query.dtype)

# Dequant kernel and scalar fallback index scales by kv_head_id, so pass
# per-channel (Hkv, D) scales directly (no expansion to Hq).
output = paged_attention_prefill_with_kv_dequant(
q=query,
key_cache=key_cache,
k_qscale=k_qscale_expanded,
k_qscale=key_scale,
value_cache=value_cache,
v_qscale=v_qscale_expanded,
v_qscale=value_scale,
cu_seqlens_q=cu_q_lens,
seqlens_kv=seqlens_kv,
block_tables=block_tables,
gqa_interleave=self.gqa_layout == "ABAB",
softmax_scale=softmax_scale,
max_seqlen_q=max_q_len,
max_seqlen_k=max_total_seq_len,
key_cache_dequant=key_cache_dequant,
value_cache_dequant=value_cache_dequant,
)

return output
Expand Down
Loading