Commit fb4c7f8

tdoublep, jvlunteren, bringlein, and cyang49 authored
[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. (#14431)
Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Jan van Lunteren <[email protected]>
Co-authored-by: Burkhard Ringlein <[email protected]>
Co-authored-by: Chih-Chieh Yang <[email protected]>
1 parent 0b1cfa6 commit fb4c7f8

File tree

1 file changed: +63 −40 lines

vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Authors:
-# - Burkhard Ringlein
-# - Jan van Lunteren
-# - Thomas Parnell
+# - Burkhard Ringlein <[email protected]>
+# - Jan van Lunteren <[email protected]>
+# - Chih-Chieh Yang <[email protected]>
+# - Thomas Parnell <[email protected]>
 
 import torch
 import triton
@@ -31,6 +32,7 @@ def kernel_paged_attention_2d(
         v_scale,  # float32
         num_query_heads: tl.constexpr,  # int
         num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
         block_table_stride: tl.constexpr,  # int
         query_stride_0: tl.constexpr,  # int
         query_stride_1: tl.constexpr,  # int, should be equal to head_size
@@ -55,8 +57,7 @@ def kernel_paged_attention_2d(
         query_start_len_ptr,  # [num_seqs+1]
 ):
     seq_idx = tl.program_id(0)
-    query_head_idx = tl.program_id(1)
-    kv_head_idx = query_head_idx // num_queries_per_kv
+    kv_head_idx = tl.program_id(1)
 
     if filter_by_query_len:
         cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
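
Note on the launch-grid change above: the kernel previously launched one program per (sequence, query head) and derived the KV head by integer division; it now launches one program per (sequence, KV head), and each program processes the whole group of query heads that share that KV head, so K/V blocks are fetched once per group instead of once per query head. A minimal plain-Python sketch of the two mappings, with illustrative head counts that are not part of the diff:

# Sketch only: illustrative GQA head counts (not from the commit).
num_query_heads = 32
num_kv_heads = 8
num_queries_per_kv = num_query_heads // num_kv_heads  # 4

# Old mapping: grid axis 1 = query heads; KV head derived per program.
old_grid = [(q, q // num_queries_per_kv) for q in range(num_query_heads)]

# New mapping: grid axis 1 = KV heads; each program owns a query group.
new_grid = [(kv, list(range(kv * num_queries_per_kv,
                            (kv + 1) * num_queries_per_kv)))
            for kv in range(num_kv_heads)]

assert len(old_grid) == 32 and len(new_grid) == 8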
@@ -69,31 +70,40 @@ def kernel_paged_attention_2d(
     else:
         cur_batch_in_all_start_index = seq_idx
 
+    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
+        0, num_queries_per_kv_padded)
+
     query_offset = (cur_batch_in_all_start_index * query_stride_0 +
-                    query_head_idx * query_stride_1)
+                    query_head_idx[:, None] * query_stride_1)
+
+    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
+    head_mask = head_mask & (query_head_idx < num_query_heads)
 
     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                         0).to(tl.int1)
 
-    # Q : (HEAD_SIZE,)
+    # Q : (num_queries_per_kv, HEAD_SIZE,)
     Q = tl.load(
-        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
-        mask=dim_mask,
+        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        mask=dim_mask[None, :] & head_mask[:, None],
         other=0.0,
     )
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([1], float("-inf"), dtype=tl.float32)
-    L = tl.full([1], 1.0, dtype=tl.float32)
-    acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
+    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
+    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
+                   dtype=tl.float32)
 
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)
 
     # alibi slope for this head
     if USE_ALIBI_SLOPES:
-        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
+        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
+                              mask=head_mask,
+                              other=0.0)
 
     num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
 
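The query-group dimension is padded to num_queries_per_kv_padded (a power of two; see the host-side computation further down), and head_mask switches off the padding rows plus any rows beyond num_query_heads. A hedged NumPy restatement of that mask logic with illustrative values:

import numpy as np

# Sketch only: reproduce the head_mask logic with illustrative values.
num_query_heads = 32
num_queries_per_kv = 4
num_queries_per_kv_padded = 16  # max(next_power_of_2(4), 16)
kv_head_idx = 7                 # last KV head

query_head_idx = kv_head_idx * num_queries_per_kv + np.arange(
    num_queries_per_kv_padded)  # rows 28..43, of which only 28..31 are real
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
head_mask = head_mask & (query_head_idx < num_query_heads)

assert head_mask.sum() == num_queries_per_kv  # padding rows masked off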
@@ -107,8 +117,8 @@ def kernel_paged_attention_2d(
 
         v_offset = (physical_block_idx * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_1 +
-                    offs_d[:, None] * stride_v_cache_2 +
-                    offs_n[None, :] * stride_v_cache_3)
+                    offs_d[None, :] * stride_v_cache_2 +
+                    offs_n[:, None] * stride_v_cache_3)
 
         k_offset = (physical_block_idx * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_1 +
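
The stride swap above loads V as (BLOCK_SIZE, HEAD_SIZE) instead of (HEAD_SIZE, BLOCK_SIZE), which is what lets both matrix products below be expressed with tl.dot. A NumPy shape sketch, with illustrative sizes and np.ndarray standing in for Triton blocks:

import numpy as np

# Sketch only: the shapes that make the two tl.dot calls line up.
G, D, B = 16, 128, 16   # padded query group, padded head size, block size

Q = np.zeros((G, D))    # query group for one KV head
K = np.zeros((D, B))    # K block, layout unchanged by this commit
V = np.zeros((B, D))    # V block, transposed by this commit

S = Q @ K                   # (G, B) scores, was a (B,) vector before
acc_update = np.exp(S) @ V  # (G, D), was tl.sum(V * P[None, :], axis=1)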
@@ -126,61 +136,69 @@ def kernel_paged_attention_2d(
         else:
             K = K_load
 
-        # V : (HEAD_SIZE, BLOCK_SIZE)
+        # V : (BLOCK_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[None, :],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
             V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
         else:
             V = V_load
 
-        tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
-        mask_new = tmp < boundary
-        # S : (BLOCK_SIZE,)
-        S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
-        S += scale * tl.sum(K * Q[:, None], axis=0)
+        seq_mask = seq_offset[None, :] < boundary
+
+        # S : (num_queries_per_kv, BLOCK_SIZE,)
+        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
+                     float("-inf")).to(tl.float32)
+        S += scale * tl.dot(Q, K)
+
+        context_len = seq_len - 1
 
         if SLIDING_WINDOW > 0:
-            S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
+            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
+                         -10000)
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope * (tmp - seq_len + 1)
+            S += alibi_slope[:, None] * (seq_offset - context_len)
 
         # compute running maximum
-        # m_j : (1,)
-        m_j = tl.maximum(M, tl.max(S, axis=0))
+        # m_j : (num_queries_per_kv,)
+        m_j = tl.maximum(M, tl.max(S, axis=1))
 
-        # P : (BLOCK_SIZE,)
-        P = tl.exp(S - m_j)
+        # P : (num_queries_per_kv, BLOCK_SIZE,)
+        P = tl.exp(S - m_j[:, None])
 
-        # l_j : (1,)
-        l_j = tl.sum(P, axis=0)
+        # l_j : (num_queries_per_kv,)
+        l_j = tl.sum(P, axis=1)
 
-        # alpha : (1, )
+        # alpha : (num_queries_per_kv, )
         alpha = tl.exp(M - m_j)
 
-        # acc : (BLOCK_SIZE,)
-        acc = acc * alpha
+        # acc : (num_queries_per_kv, BLOCK_SIZE,)
+        acc = acc * alpha[:, None]
 
         # update constants
         L = L * alpha + l_j
         M = m_j
 
-        # acc : (BLOCK_SIZE,)
-        acc += tl.sum(V * P[None, :], axis=1)
+        # acc : (num_queries_per_kv, BLOCK_SIZE,)
+        acc += tl.dot(P.to(V.dtype), V)
 
     # epilogue
-    acc = acc / L
+    acc = acc / L[:, None]
 
     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)
 
-    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
-             acc,
-             mask=dim_mask)
+    tl.store(
+        output_ptr + output_offset[:, None] +
+        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        acc,
+        mask=dim_mask[None, :] & head_mask[:, None],
+    )
 
 
 def chunked_prefill_paged_decode(
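
The online-softmax recurrence itself is unchanged by this commit: M, L and acc simply gain a leading query-group dimension, the axis-0 reductions become axis-1, and the elementwise reductions become tl.dot matrix multiplies that can use the GPU's matrix units. A hedged NumPy reference of the blockwise update for one query group (a sketch, not the kernel):

import numpy as np

def online_softmax_attention(Q, K_blocks, V_blocks, scale):
    """Blockwise softmax(scale * Q @ K) @ V with a running max/normalizer.

    Q: (G, D); each K block: (D, B); each V block: (B, D).
    """
    G, D = Q.shape
    M = np.full(G, -np.inf)   # running row maxima
    L = np.ones(G)            # running normalizers (first alpha zeroes the 1)
    acc = np.zeros((G, D))    # unnormalized output accumulator
    for K, V in zip(K_blocks, V_blocks):
        S = scale * (Q @ K)                 # (G, B) scores for this block
        m_j = np.maximum(M, S.max(axis=1))  # updated row maxima
        P = np.exp(S - m_j[:, None])        # shifted exponentials
        alpha = np.exp(M - m_j)             # rescale factor for old state
        acc = acc * alpha[:, None] + P @ V
        L = L * alpha + P.sum(axis=1)
        M = m_j
    return acc / L[:, None]

With G = 1 this degenerates to the scalar recurrence the kernel used before this change.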
@@ -234,6 +252,7 @@ def chunked_prefill_paged_decode(
     block_size = value_cache.shape[3]
     num_seqs = len(seq_lens)
     num_query_heads = query.shape[1]
+    num_kv_heads = key.shape[1]
     num_queries_per_kv = query.shape[1] // key.shape[1]
     head_size = query.shape[2]
@@ -253,9 +272,12 @@ def chunked_prefill_paged_decode(
     key_cache = key_cache.view(target_dtype)
     value_cache = value_cache.view(target_dtype)
 
+    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
+                                    16)
+
     kernel_paged_attention_2d[(
         num_seqs,
-        num_query_heads,
+        num_kv_heads,
     )](
         output_ptr=output,
         query_ptr=query,
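
On the padding choice above: tl.dot has historically required each tile dimension to be a power of two and at least 16, so even MHA (one query per KV head) pads the group dimension up to 16 and relies on head_mask to discard the dead rows. A quick check of the values this produces (plain Python; triton.next_power_of_2 is the same helper the host code calls):

import triton

def padded_group(num_queries_per_kv: int) -> int:
    # Mirrors the host-side computation in this commit.
    return max(triton.next_power_of_2(num_queries_per_kv), 16)

# MHA and common GQA ratios all pad up to the 16-row floor:
assert [padded_group(n) for n in (1, 2, 4, 8, 16)] == [16] * 5
assert padded_group(32) == 32  # larger groups keep their power of two

Since num_queries_per_kv_padded is passed as a tl.constexpr parameter, each distinct value triggers its own kernel specialization.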
@@ -269,6 +291,7 @@ def chunked_prefill_paged_decode(
         v_scale=v_scale,
         num_query_heads=num_query_heads,
         num_queries_per_kv=num_queries_per_kv,
+        num_queries_per_kv_padded=num_queries_per_kv_padded,
         block_table_stride=block_table.stride(0),
         query_stride_0=query.stride(0),
         query_stride_1=query.stride(1),
