
Commit 39af9c8

[dlinfer]change llm op interface of paged_prefill_attention. (#2977)
* [dlinfer]modify interface to support camb multi-batch-conv
* [dlinfer]change order for paged_prefill

1 parent: f066384

2 files changed (+11 −0 lines)

lmdeploy/pytorch/backends/dlinfer/attention.py (+4)

@@ -16,6 +16,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
     max_q_seq_len: int = 1
     max_kv_seq_len: int = 1
     quant_meta: Dict = None
+    cu_seq_lens_kv: Optional[Tensor] = None
 
 
 class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):

@@ -79,6 +80,8 @@ def forward(
         max_q_seq_len = attn_metadata.max_q_seq_len
         max_kv_seq_len = attn_metadata.max_kv_seq_len
         quant_bits = attn_metadata.quant_policy
+        cu_seq_lens_kv = attn_metadata.cu_seq_lens_kv
+
         if attn_metadata.quant_meta is not None:
             k_scales_zeros = [
                 next(attn_metadata.quant_meta['k_scales']),

@@ -128,6 +131,7 @@ def forward(
             q_start_loc=q_start_loc,
             q_seqlens=q_seqlens,
             kv_seqlens=kv_seqlens,
+            cu_seq_lens_kv=cu_seq_lens_kv,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
             is_decoding=is_decoding,
lmdeploy/pytorch/kernels/dlinfer/pagedattention.py (+7)

@@ -15,7 +15,9 @@ def prefill_attention(
     q_start_loc: Tensor,
     q_seq_len: Tensor,
     kv_seq_len: Tensor,
+    cu_seq_lens_kv: Tensor,
     max_q_seq_len: int,
+    max_kv_seq_len: int,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]],
     is_unpaged_prefill: Optional[bool],

@@ -51,7 +53,9 @@ def prefill_attention(
         q_start_loc,
         q_seq_len,
         kv_seq_len,
+        cu_seq_lens_kv,
         max_q_seq_len,
+        max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
         attn_mask,

@@ -105,6 +109,7 @@ def paged_attention_fwd(
     q_start_loc: Tensor,
     q_seqlens: Tensor,
     kv_seqlens: Tensor,
+    cu_seq_lens_kv: Tensor,
     max_q_seq_len: int,
     max_kv_seq_len: int,
     is_decoding: bool,

@@ -127,7 +132,9 @@ def paged_attention_fwd(
         q_start_loc,
         q_seqlens,
         kv_seqlens,
+        cu_seq_lens_kv,
         max_q_seq_len,
+        max_kv_seq_len,
         block_size,
         attn_mask,
         is_unpaged_prefill,
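Both cu_seq_lens_kv and max_kv_seq_len are inserted mid-list and passed positionally in the internal call, so any backend implementing this dlinfer op interface must update its signature and argument order to match. A hypothetical sketch of how a backend kernel could consume cu_seq_lens_kv (the helper name is illustrative, not part of the actual dlinfer ext_ops API):

    import torch

    # Hypothetical helper, not from this commit: recover each sequence's
    # half-open KV range [start, end) from the cumulative lengths.
    def iter_kv_ranges(cu_seq_lens_kv: torch.Tensor):
        for i in range(cu_seq_lens_kv.numel() - 1):
            yield int(cu_seq_lens_kv[i]), int(cu_seq_lens_kv[i + 1])

    cu = torch.tensor([0, 7, 10, 15])
    assert list(iter_kv_ranges(cu)) == [(0, 7), (7, 10), (10, 15)]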
