@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0

 # Authors:
-#  - Burkhard Ringlein
-#  - Jan van Lunteren
-#  - Thomas Parnell
+#  - Burkhard Ringlein <[email protected]>
+#  - Jan van Lunteren <[email protected]>
+#  - Chih-Chieh Yang <[email protected]>
+#  - Thomas Parnell <[email protected]>

 import torch
 import triton
@@ -31,6 +32,7 @@ def kernel_paged_attention_2d(
         v_scale,  # float32
         num_query_heads: tl.constexpr,  # int
         num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
         block_table_stride: tl.constexpr,  # int
         query_stride_0: tl.constexpr,  # int
         query_stride_1: tl.constexpr,  # int, should be equal to head_size
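Note: the padded group size is passed as a `tl.constexpr` so the kernel's tile shapes are compile-time constants. The host side (see the launch hunk below) pads the GQA group to a power of two with a floor of 16, presumably to satisfy `tl.dot`'s minimum tile size. A standalone sketch of that rule; `pad_group_size` is a hypothetical helper name, not part of this PR:

```python
import triton

def pad_group_size(num_queries_per_kv: int) -> int:
    # Pad the GQA group dimension to a power of two, floored at 16,
    # so it is a legal tl.dot operand dimension.
    return max(triton.next_power_of_2(num_queries_per_kv), 16)

# e.g. groups of 4 query heads per KV head are padded up to 16 rows
assert pad_group_size(4) == 16
```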
@@ -55,8 +57,7 @@ def kernel_paged_attention_2d(
         query_start_len_ptr,  # [num_seqs+1]
 ):
     seq_idx = tl.program_id(0)
-    query_head_idx = tl.program_id(1)
-    kv_head_idx = query_head_idx // num_queries_per_kv
+    kv_head_idx = tl.program_id(1)

     if filter_by_query_len:
         cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
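Axis 1 of the grid now enumerates KV heads rather than query heads, so one program computes attention for an entire GQA group at once. An illustrative view of the new mapping (pure Python, assumed sizes):

```python
# Illustrative only: 8 query heads sharing 2 KV heads, i.e.
# num_queries_per_kv = 4. Each program now owns a full group.
num_query_heads, num_kv_heads = 8, 2
num_queries_per_kv = num_query_heads // num_kv_heads
for kv_head_idx in range(num_kv_heads):
    group = [kv_head_idx * num_queries_per_kv + q
             for q in range(num_queries_per_kv)]
    print(kv_head_idx, group)  # 0 -> [0, 1, 2, 3], 1 -> [4, 5, 6, 7]
```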
@@ -69,31 +70,40 @@ def kernel_paged_attention_2d(
     else:
         cur_batch_in_all_start_index = seq_idx

+    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
+        0, num_queries_per_kv_padded)
+
     query_offset = (cur_batch_in_all_start_index * query_stride_0 +
-                    query_head_idx * query_stride_1)
+                    query_head_idx[:, None] * query_stride_1)
+
+    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
+    head_mask = head_mask & (query_head_idx < num_query_heads)

     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                         0).to(tl.int1)

-    # Q : (HEAD_SIZE,)
+    # Q : (num_queries_per_kv, HEAD_SIZE,)
     Q = tl.load(
-        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
-        mask=dim_mask,
+        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        mask=dim_mask[None, :] & head_mask[:, None],
         other=0.0,
     )

     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([1], float("-inf"), dtype=tl.float32)
-    L = tl.full([1], 1.0, dtype=tl.float32)
-    acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
+    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
+    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
+                   dtype=tl.float32)

     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)

     # alibi slope for this head
     if USE_ALIBI_SLOPES:
-        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
+        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
+                              mask=head_mask,
+                              other=0.0)

     num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

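The `head_mask` built above does double duty: it masks off the power-of-two padding rows inside a group and, when head counts are irregular, any rows beyond `num_query_heads`. A worked example with assumed sizes:

```python
import torch

# Assumed sizes for illustration: 32 query heads in groups of 4,
# padded to 16 rows per program.
num_query_heads, num_queries_per_kv, padded = 32, 4, 16
kv_head_idx = 7
query_head_idx = kv_head_idx * num_queries_per_kv + torch.arange(padded)
head_mask = (query_head_idx < (kv_head_idx + 1) * num_queries_per_kv) & (
    query_head_idx < num_query_heads)
print(query_head_idx.tolist())  # [28, 29, 30, ..., 43]
print(head_mask.tolist())       # True for 28..31, False for the 12 pad rows
```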
@@ -107,8 +117,8 @@ def kernel_paged_attention_2d(

         v_offset = (physical_block_idx * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_1 +
-                    offs_d[:, None] * stride_v_cache_2 +
-                    offs_n[None, :] * stride_v_cache_3)
+                    offs_d[None, :] * stride_v_cache_2 +
+                    offs_n[:, None] * stride_v_cache_3)

         k_offset = (physical_block_idx * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_1 +
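Swapping `offs_d` and `offs_n` transposes the in-register layout of `V` from (HEAD_SIZE, BLOCK_SIZE) to (BLOCK_SIZE, HEAD_SIZE), so the accumulator update later in this diff can be a single `tl.dot(P, V)` with no explicit transpose. A shape sanity check under assumed tile sizes:

```python
import torch

# Assumed tile sizes: padded group G, block B, head dimension D.
G, B, D = 16, 32, 128
P = torch.rand(G, B)   # attention probabilities for one block
V = torch.rand(B, D)   # values in the new (BLOCK_SIZE, HEAD_SIZE) layout
assert (P @ V).shape == (G, D)
```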
@@ -126,61 +136,69 @@ def kernel_paged_attention_2d(
         else:
             K = K_load

-        # V : (HEAD_SIZE, BLOCK_SIZE)
+        # V : (BLOCK_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[None, :],
                          other=0.0)

         if V_load.dtype.is_fp8():
             V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
         else:
             V = V_load

-        tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
-        mask_new = tmp < boundary
-        # S : (BLOCK_SIZE,)
-        S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
-        S += scale * tl.sum(K * Q[:, None], axis=0)
+        seq_mask = seq_offset[None, :] < boundary
+
+        # S : (num_queries_per_kv, BLOCK_SIZE,)
+        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
+                     float("-inf")).to(tl.float32)
+        S += scale * tl.dot(Q, K)
+
+        context_len = seq_len - 1

         if SLIDING_WINDOW > 0:
-            S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
+            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
+                         -10000)

         if USE_ALIBI_SLOPES:
-            S += alibi_slope * (tmp - seq_len + 1)
+            S += alibi_slope[:, None] * (seq_offset - context_len)

         # compute running maximum
-        # m_j : (1,)
-        m_j = tl.maximum(M, tl.max(S, axis=0))
+        # m_j : (num_queries_per_kv,)
+        m_j = tl.maximum(M, tl.max(S, axis=1))

-        # P : (BLOCK_SIZE,)
-        P = tl.exp(S - m_j)
+        # P : (num_queries_per_kv, BLOCK_SIZE,)
+        P = tl.exp(S - m_j[:, None])

-        # l_j : (1,)
-        l_j = tl.sum(P, axis=0)
+        # l_j : (num_queries_per_kv,)
+        l_j = tl.sum(P, axis=1)

-        # alpha : (1, )
+        # alpha : (num_queries_per_kv, )
         alpha = tl.exp(M - m_j)

-        # acc : (BLOCK_SIZE,)
-        acc = acc * alpha
+        # acc : (num_queries_per_kv, HEAD_SIZE_PADDED,)
+        acc = acc * alpha[:, None]

         # update constants
         L = L * alpha + l_j
         M = m_j

-        # acc : (BLOCK_SIZE,)
-        acc += tl.sum(V * P[None, :], axis=1)
+        # acc : (num_queries_per_kv, HEAD_SIZE_PADDED,)
+        acc += tl.dot(P.to(V.dtype), V)

     # epilogue
-    acc = acc / L
+    acc = acc / L[:, None]

     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)

-    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
-             acc,
-             mask=dim_mask)
+    tl.store(
+        output_ptr + output_offset[:, None] +
+        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        acc,
+        mask=dim_mask[None, :] & head_mask[:, None],
+    )


 def chunked_prefill_paged_decode(
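For intuition, the per-block update in the hunk above is the standard online-softmax (flash-attention) recurrence, now batched over the query-head group so `tl.dot` can be used for both the score and output matmuls. A plain PyTorch reference of one step; the function name and signature are hypothetical, for illustration only:

```python
import torch

def online_softmax_step(M, L, acc, S, V):
    """One block of the online softmax, batched over a query group.

    M, L: (G,) running max / normalizer; acc: (G, D) running output;
    S: (G, B) scores for this block; V: (B, D) values. Mirrors the
    m_j / P / l_j / alpha updates in the kernel hunk above.
    """
    m_j = torch.maximum(M, S.max(dim=1).values)  # new running max
    P = torch.exp(S - m_j[:, None])              # probabilities vs. new max
    l_j = P.sum(dim=1)                           # this block's normalizer
    alpha = torch.exp(M - m_j)                   # rescale factor for old state
    acc = acc * alpha[:, None] + P @ V           # rescale, then accumulate
    L = L * alpha + l_j
    return m_j, L, acc
```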
@@ -234,6 +252,7 @@ def chunked_prefill_paged_decode(
     block_size = value_cache.shape[3]
     num_seqs = len(seq_lens)
     num_query_heads = query.shape[1]
+    num_kv_heads = key.shape[1]
     num_queries_per_kv = query.shape[1] // key.shape[1]
     head_size = query.shape[2]

@@ -253,9 +272,12 @@ def chunked_prefill_paged_decode(
         key_cache = key_cache.view(target_dtype)
         value_cache = value_cache.view(target_dtype)

+    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
+                                    16)
+
     kernel_paged_attention_2d[(
         num_seqs,
-        num_query_heads,
+        num_kv_heads,
     )](
         output_ptr=output,
         query_ptr=query,
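With the grid's second axis switched to `num_kv_heads`, the number of launched programs drops by the group factor, while each program's work grows into a tensor-core-friendly matmul. A rough illustration under assumed sizes:

```python
# Illustrative only (assumed sizes): 64 sequences, 32 query heads, 8 KV heads.
num_seqs, num_query_heads, num_kv_heads = 64, 32, 8
old_grid = (num_seqs, num_query_heads)  # 2048 programs, one per query head
new_grid = (num_seqs, num_kv_heads)     # 512 programs, one per GQA group
```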
@@ -269,6 +291,7 @@ def chunked_prefill_paged_decode(
         v_scale=v_scale,
         num_query_heads=num_query_heads,
         num_queries_per_kv=num_queries_per_kv,
+        num_queries_per_kv_padded=num_queries_per_kv_padded,
         block_table_stride=block_table.stride(0),
         query_stride_0=query.stride(0),
         query_stride_1=query.stride(1),