@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Authors:
-# - Burkhard Ringlein
-# - Jan van Lunteren
-# - Thomas Parnell
+# - Burkhard Ringlein <[email protected]>
+# - Jan van Lunteren <[email protected]>
+# - Chih-Chieh Yang <[email protected]>
+# - Thomas Parnell <[email protected]>
 
 import torch
 import triton
@@ -31,6 +32,7 @@ def kernel_paged_attention_2d(
         v_scale,  # float32
         num_query_heads: tl.constexpr,  # int
         num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
         block_table_stride: tl.constexpr,  # int
         query_stride_0: tl.constexpr,  # int
         query_stride_1: tl.constexpr,  # int, should be equal to head_size
@@ -55,8 +57,7 @@ def kernel_paged_attention_2d(
         query_start_len_ptr,  # [num_seqs+1]
 ):
     seq_idx = tl.program_id(0)
-    query_head_idx = tl.program_id(1)
-    kv_head_idx = query_head_idx // num_queries_per_kv
+    kv_head_idx = tl.program_id(1)
 
     if filter_by_query_len:
         cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
@@ -69,31 +70,40 @@ def kernel_paged_attention_2d(
     else:
         cur_batch_in_all_start_index = seq_idx
 
+    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
+        0, num_queries_per_kv_padded)
+
     query_offset = (cur_batch_in_all_start_index * query_stride_0 +
-                    query_head_idx * query_stride_1)
+                    query_head_idx[:, None] * query_stride_1)
+
+    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
+    head_mask = head_mask & (query_head_idx < num_query_heads)
 
     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                         0).to(tl.int1)
 
-    # Q : (HEAD_SIZE,)
+    # Q : (num_queries_per_kv, HEAD_SIZE,)
     Q = tl.load(
-        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
-        mask=dim_mask,
+        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        mask=dim_mask[None, :] & head_mask[:, None],
         other=0.0,
     )
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([1], float("-inf"), dtype=tl.float32)
-    L = tl.full([1], 1.0, dtype=tl.float32)
-    acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
+    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
+    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
+                   dtype=tl.float32)
 
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)
 
     # alibi slope for this head
     if USE_ALIBI_SLOPES:
-        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
+        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
+                              mask=head_mask,
+                              other=0.0)
 
     num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
 
@@ -107,8 +117,8 @@ def kernel_paged_attention_2d(
 
         v_offset = (physical_block_idx * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_1 +
-                    offs_d[:, None] * stride_v_cache_2 +
-                    offs_n[None, :] * stride_v_cache_3)
+                    offs_d[None, :] * stride_v_cache_2 +
+                    offs_n[:, None] * stride_v_cache_3)
 
         k_offset = (physical_block_idx * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_1 +
@@ -126,61 +136,69 @@ def kernel_paged_attention_2d(
         else:
             K = K_load
 
-        # V : (HEAD_SIZE, BLOCK_SIZE)
+        # V : (BLOCK_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[None, :],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
             V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
         else:
             V = V_load
 
-        tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
-        mask_new = tmp < boundary
-        # S : (BLOCK_SIZE,)
-        S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
-        S += scale * tl.sum(K * Q[:, None], axis=0)
+        seq_mask = seq_offset[None, :] < boundary
+
+        # S : (num_queries_per_kv, BLOCK_SIZE,)
+        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
+                     float("-inf")).to(tl.float32)
+        S += scale * tl.dot(Q, K)
+
+        context_len = seq_len - 1
 
         if SLIDING_WINDOW > 0:
-            S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
+            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
+                         -10000)
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope * (tmp - seq_len + 1)
+            S += alibi_slope[:, None] * (seq_offset - context_len)
 
         # compute running maximum
-        # m_j : (1,)
-        m_j = tl.maximum(M, tl.max(S, axis=0))
+        # m_j : (num_queries_per_kv,)
+        m_j = tl.maximum(M, tl.max(S, axis=1))
 
-        # P : (BLOCK_SIZE,)
-        P = tl.exp(S - m_j)
+        # P : (num_queries_per_kv, BLOCK_SIZE,)
+        P = tl.exp(S - m_j[:, None])
 
-        # l_j : (1,)
-        l_j = tl.sum(P, axis=0)
+        # l_j : (num_queries_per_kv,)
+        l_j = tl.sum(P, axis=1)
 
-        # alpha : (1, )
+        # alpha : (num_queries_per_kv, )
         alpha = tl.exp(M - m_j)
 
-        # acc : (BLOCK_SIZE,)
-        acc = acc * alpha
+        # acc : (num_queries_per_kv, BLOCK_SIZE,)
+        acc = acc * alpha[:, None]
 
         # update constants
         L = L * alpha + l_j
         M = m_j
 
-        # acc : (BLOCK_SIZE,)
-        acc += tl.sum(V * P[None, :], axis=1)
+        # acc : (num_queries_per_kv, BLOCK_SIZE,)
+        acc += tl.dot(P.to(V.dtype), V)
 
     # epilogue
-    acc = acc / L
+    acc = acc / L[:, None]
 
     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)
 
-    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
-             acc,
-             mask=dim_mask)
+    tl.store(
+        output_ptr + output_offset[:, None] +
+        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        acc,
+        mask=dim_mask[None, :] & head_mask[:, None],
+    )
 
 
 def chunked_prefill_paged_decode(
@@ -234,6 +252,7 @@ def chunked_prefill_paged_decode(
     block_size = value_cache.shape[3]
     num_seqs = len(seq_lens)
     num_query_heads = query.shape[1]
+    num_kv_heads = key.shape[1]
     num_queries_per_kv = query.shape[1] // key.shape[1]
     head_size = query.shape[2]
 
@@ -253,9 +272,12 @@ def chunked_prefill_paged_decode(
     key_cache = key_cache.view(target_dtype)
     value_cache = value_cache.view(target_dtype)
 
+    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
+                                    16)
+
     kernel_paged_attention_2d[(
         num_seqs,
-        num_query_heads,
+        num_kv_heads,
     )](
         output_ptr=output,
         query_ptr=query,
@@ -269,6 +291,7 @@ def chunked_prefill_paged_decode(
         v_scale=v_scale,
         num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
+        num_queries_per_kv_padded=num_queries_per_kv_padded,
         block_table_stride=block_table.stride(0),
         query_stride_0=query.stride(0),
         query_stride_1=query.stride(1),
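The host-side padding introduced in the last hunk exists because each program now covers the whole group of query heads that shares one KV head (the grid is launched over num_kv_heads instead of num_query_heads), and tl.dot generally wants block dimensions that are powers of two and at least 16. A minimal sketch of that arithmetic, assuming hypothetical head counts rather than values taken from the diff:

# Sketch (not part of the diff): how the padded query-head group size
# passed to kernel_paged_attention_2d is derived on the host side.
import triton

num_query_heads = 32                                  # hypothetical GQA model
num_kv_heads = 8                                      # hypothetical
num_queries_per_kv = num_query_heads // num_kv_heads  # 4 query heads per KV head
# tl.dot generally requires power-of-two block dims of at least 16, so the
# group is padded up; the extra rows are disabled in-kernel via head_mask.
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16)
print(num_queries_per_kv_padded)  # 16 -> 4 real rows + 12 masked padding rows

With 4 real query heads per KV head, the Q block is padded to 16 rows and head_mask zeroes out the 12 padding rows, so the padded rows never contribute to the stored output.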