[ascend]feat: support kv int8 (#2736)
* [ascend]feat: support kv int8 quant

* update code

* fix missing-argument error

* update params

* fix not iterable error when quant_meta is None.

* Update ascend_en_get_started.md for kvcache quant

* Update ascend_cn_get_started.md for kvcache quant

---------

Co-authored-by: jinminxi104 <[email protected]>
yao-fengchen and jinminxi104 authored Dec 6, 2024
1 parent 9bfdeae commit 866bfa5
Showing 13 changed files with 199 additions and 11 deletions.
6 changes: 6 additions & 0 deletions docs/en/get_started/ascend/get_started.md
@@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```

Please check [supported_models](../../supported_models/supported_models.md) before using this feature.

### int8 KV-cache Quantization

The Ascend backend supports offline int8 KV-cache quantization in eager mode.

Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details.
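
To make the doc section above concrete, here is a minimal usage sketch (not part of this diff). It assumes lmdeploy's `pipeline` / `PytorchEngineConfig` API and an `eager_mode` flag on the engine config; `quant_policy=8`, `device_type='ascend'`, and the `ASCEND_QUANT_RECORD_FILE` environment variable are the pieces introduced or exercised by this commit, and the model path is a placeholder.

```python
# Hedged sketch: enable offline int8 KV-cache quantization on Ascend.
# quant_policy=8, device_type='ascend' and ASCEND_QUANT_RECORD_FILE come from
# this commit; eager_mode is assumed to exist on PytorchEngineConfig
# (graph mode is not covered by this feature).
import os

from lmdeploy import pipeline, PytorchEngineConfig

# Path to the per-layer scale/offset record produced by the dlinfer
# quantization tooling (see the linked ascend_kv_quant.md).
os.environ['ASCEND_QUANT_RECORD_FILE'] = '/path/to/kv_cache_quant_record.txt'

backend_config = PytorchEngineConfig(
    device_type='ascend',  # int8 KV cache is only allowed for cuda/ascend
    eager_mode=True,       # assumption: eager-mode flag; quant is eager-only
    quant_policy=8,        # 8 -> int8 KV cache
)

pipe = pipeline('/path/to/model', backend_config=backend_config)
print(pipe(['Hello, Ascend int8 KV cache!']))
```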
6 changes: 6 additions & 0 deletions docs/zh_cn/get_started/ascend/get_started.md
@@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```

支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)

### int8 KV-cache 量化

昇腾后端支持在 eager 模式下进行离线 int8 KV-cache 量化。

详细使用方式请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。
7 changes: 5 additions & 2 deletions lmdeploy/messages.py
@@ -293,8 +293,11 @@ def __post_init__(self):
assert self.device_type in [
'cuda', 'ascend', 'maca'
], (f'invalid device_type: {self.device_type}')
if self.quant_policy > 0 and self.device_type != 'cuda':
assert False, 'kv cache quantization only works for CUDA.'
if self.quant_policy > 0 and self.device_type not in [
'cuda', 'ascend'
]:
assert False, \
'kv cache quantization only works for CUDA and ASCEND.'


class ResponseType(enum.Enum):
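
A quick illustration of the relaxed check above, assuming this `__post_init__` belongs to `PytorchEngineConfig` (the class name is not visible in the hunk): `quant_policy > 0` is now accepted for `device_type='ascend'`, while other non-CUDA backends still trip the assertion.

```python
# Hedged sketch of the new validation behavior in messages.py.
from lmdeploy import PytorchEngineConfig

# Accepted after this change: int8 KV cache on Ascend.
PytorchEngineConfig(device_type='ascend', quant_policy=8)

# Still rejected: KV-cache quantization on an unsupported backend.
try:
    PytorchEngineConfig(device_type='maca', quant_policy=8)
except AssertionError as err:
    print(err)  # 'kv cache quantization only works for CUDA and ASCEND.'
```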
88 changes: 87 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import itertools
import os
import re
from pathlib import Path
from typing import Dict, Tuple

import torch

@@ -11,6 +15,71 @@
logger = get_logger('lmdeploy')


class AscendKVQuantMeta:
has_set_value: bool = False
quant_meta: Dict = {}

@classmethod
def set_value(cls, device: str, dtype: torch.dtype, record_file: str,
total_layers: int):
with open(record_file, 'r') as file:
data = file.read()
scale_offset_pairs = re.findall(
r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data)
scale_offset_pairs = [(float(scale), float(offset))
for scale, offset in scale_offset_pairs]
k_scales, v_scales, kv_scales = [], [], []
k_zeros, v_zeros, kv_zeros = [], [], []
if len(scale_offset_pairs) == total_layers:
for scale, offset in scale_offset_pairs:
k_scales.append(
torch.tensor([scale], device=device, dtype=dtype))
v_scales.append(
torch.tensor([scale], device=device, dtype=dtype))
kv_scales.append(
torch.tensor([scale, scale], device=device, dtype=dtype))
k_zeros.append(
torch.tensor([offset], device=device, dtype=dtype))
v_zeros.append(
torch.tensor([offset], device=device, dtype=dtype))
kv_zeros.append(
torch.tensor([offset, offset], device=device, dtype=dtype))
elif len(scale_offset_pairs) == total_layers * 2:
for i in range(total_layers):
scale_k, offset_k = scale_offset_pairs[2 * i]
scale_v, offset_v = scale_offset_pairs[2 * i + 1]
k_scales.append(
torch.tensor([scale_k], device=device, dtype=dtype))
v_scales.append(
torch.tensor([scale_v], device=device, dtype=dtype))
kv_scales.append(
torch.tensor([scale_k, scale_v],
device=device,
dtype=dtype))
k_zeros.append(
torch.tensor([offset_k], device=device, dtype=dtype))
v_zeros.append(
torch.tensor([offset_v], device=device, dtype=dtype))
kv_zeros.append(
torch.tensor([offset_k, offset_v],
device=device,
dtype=dtype))
else:
raise ValueError(
f'num of scale_offset_pairs({len(scale_offset_pairs)}) '
f'must match num of total_layers({total_layers})')

cls.quant_meta.update({
'k_scales': itertools.cycle(k_scales),
'k_zeros': itertools.cycle(k_zeros),
'v_scales': itertools.cycle(v_scales),
'v_zeros': itertools.cycle(v_zeros),
'kv_scales': itertools.cycle(kv_scales),
'kv_zeros': itertools.cycle(kv_zeros)
})
cls.has_set_value = True


class AscendOpsBackend(DlinferOpsBackend):
"""ascend layer backend."""
enable_graph = False
@@ -164,6 +233,21 @@ def get_total_slots():
.repeat_interleave(step_context.q_seqlens, 0)
kv_seqlens = kv_seqlens_cpu

if not cls.enable_graph and step_context.kv_quant_policy == 8:
record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'
path = Path(record_file)
is_path = path.is_absolute() or path.is_relative_to('/')
exists = path.exists()
if not (is_path and exists):
raise ValueError(
'please specify valid ASCEND_QUANT_RECORD_FILE')
if not AscendKVQuantMeta.has_set_value:
total_layers = len(step_context.kv_caches)
AscendKVQuantMeta.set_value(step_context.block_offsets.device,
step_context.model_config.dtype,
record_file, total_layers)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
Expand All @@ -177,6 +261,8 @@ def get_total_slots():
is_unpaged_prefill=is_unpaged_prefill,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
quant_policy=step_context.kv_quant_policy,
quant_meta=AscendKVQuantMeta.quant_meta,
)

step_context.attn_metadata = attn_metadata
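
The record file parsed by `AscendKVQuantMeta.set_value` is produced by the quantization flow described in the dlinfer doc linked above; its exact layout is not part of this diff, so the sketch below (made-up values and contents) only mirrors what the regex and the layer-count check accept: one `scale: ... offset: ...` pair per layer, or two pairs per layer when K and V are calibrated separately.

```python
# Hedged sketch of the record-file contract implied by AscendKVQuantMeta:
# values are made up; only the "scale: ... offset: ..." pattern and the
# "total_layers or 2 * total_layers" rule come from the code above.
import re

record_text = """\
scale: 0.0123 offset: -3
scale: 0.0110 offset: 2
scale: 0.0150 offset: 0
scale: 0.0098 offset: -1
"""  # 4 pairs -> either 4 layers (shared K/V) or 2 layers (separate K and V)

pairs = re.findall(r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', record_text)
pairs = [(float(s), float(o)) for s, o in pairs]

total_layers = 2
if len(pairs) == total_layers:
    print('one shared scale/offset per layer')
elif len(pairs) == total_layers * 2:
    print('separate K and V scale/offset per layer')  # taken here
else:
    raise ValueError('record file does not match the layer count')
```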
37 changes: 34 additions & 3 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Dict, Optional, Sequence

from torch import Tensor

@@ -15,6 +15,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
is_unpaged_prefill: Optional[bool] = None
max_q_seq_len: int = 1
max_kv_seq_len: int = 1
quant_meta: Dict = None


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
@@ -74,10 +75,37 @@ def forward(
is_unpaged_prefill = attn_metadata.is_unpaged_prefill
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
quant_bits = attn_metadata.quant_policy
if attn_metadata.quant_meta is not None:
k_scales_zeros = [
next(attn_metadata.quant_meta['k_scales']),
next(attn_metadata.quant_meta['k_zeros'])
] if 'k_scales' in attn_metadata.quant_meta else []
v_scales_zeros = [
next(attn_metadata.quant_meta['v_scales']),
next(attn_metadata.quant_meta['v_zeros'])
] if 'v_scales' in attn_metadata.quant_meta else []
kv_scales = next(
attn_metadata.quant_meta['kv_scales']
) if 'kv_scales' in attn_metadata.quant_meta else None
kv_zeros = next(
attn_metadata.quant_meta['kv_zeros']
) if 'kv_zeros' in attn_metadata.quant_meta else None
else:
k_scales_zeros = []
v_scales_zeros = []
kv_scales = None
kv_zeros = None

# fill kv cache
k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
kv_start_indices)
k_cache, v_cache = self.fill_kv_cache(key,
value,
k_cache,
v_cache,
kv_start_indices,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_bits=quant_bits)

if inplace:
attn_output = query[..., :self.v_head_size]
@@ -103,6 +131,9 @@ def forward(
block_size=block_size,
attn_mask=attn_mask,
is_unpaged_prefill=is_unpaged_prefill,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)

return attn_output
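
A note on the `next(attn_metadata.quant_meta[...])` calls above: all attention layers share one `quant_meta` dict built from `itertools.cycle` iterators, so each layer's `forward` pulls its own per-layer scale/zero in call order and the iterator wraps around at the next model pass. A stripped-down sketch of that pattern (layer count and values are made up):

```python
# Minimal illustration of the cycle-based per-layer lookup used above.
import itertools

num_layers = 3
k_scales = [f'k_scale_layer_{i}' for i in range(num_layers)]  # stand-ins for tensors
quant_meta = {'k_scales': itertools.cycle(k_scales)}

# Two "forward passes" over all layers: the cycle restarts automatically,
# so layer i always receives its own scale without explicit indexing.
for step in range(2):
    for layer in range(num_layers):
        assert next(quant_meta['k_scales']) == f'k_scale_layer_{layer}'
```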
1 change: 1 addition & 0 deletions lmdeploy/pytorch/config.py
@@ -77,6 +77,7 @@ class CacheConfig:
max_prefill_token_num: int = 4096
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
device_type: str = 'cuda'

def __post_init__(self):
"""post init."""
8 changes: 7 additions & 1 deletion lmdeploy/pytorch/engine/cache_engine.py
@@ -44,7 +44,13 @@ def __init__(
self.num_layers = model_config.num_layers
self.kv_cache_dtype = model_config.dtype
if cache_config.quant_policy > 0:
self.kv_cache_dtype = torch.uint8
if self.cache_config.device_type in ['cuda']:
self.kv_cache_dtype = torch.uint8
elif self.cache_config.device_type in ['ascend', 'npu']:
self.kv_cache_dtype = torch.int8
else:
raise ValueError(
f'unsupported device_type {self.cache_config.device_type}')

# Initialize the cache.
self.local_gpu_cache = self.allocate_gpu_cache()
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -130,6 +130,7 @@ def __init__(self,
max_prefill_token_num=engine_config.max_prefill_token_num,
enable_prefix_caching=engine_config.enable_prefix_caching,
quant_policy=engine_config.quant_policy,
device_type=engine_config.device_type,
)

if not os.path.exists(model_path):
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -139,6 +139,7 @@ def model_forward(
ctx_mgr = model.ctx_mgr
context = ctx_mgr.build_context(
inputs=inputs,
model_config=cache_engine.model_config,
world_size=world_size,
kv_caches=cache_engine.gpu_cache,
kv_quant_policy=cache_engine.cache_config.quant_policy,
15 changes: 13 additions & 2 deletions lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import dlinfer.ops as ext_ops
from torch import Tensor

@@ -9,7 +11,16 @@ def fill_kv_cache(
key_caches: Tensor,
value_caches: Tensor,
kv_start_indices: Tensor,
k_scales_zeros: Sequence[Optional[Tensor]],
v_scales_zeros: Sequence[Optional[Tensor]],
quant_bits: int = 0,
):
"""fill key/value state to cache for paged attention."""
return ext_ops.fill_kv_cache(key_states, value_states, key_caches,
value_caches, kv_start_indices)
return ext_ops.fill_kv_cache(key_states,
value_states,
key_caches,
value_caches,
kv_start_indices,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_bits=quant_bits)
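
For intuition about what the `k_scales_zeros` / `v_scales_zeros` arguments carry, below is a generic per-tensor affine int8 (de)quantization sketch. The real quantize/dequantize happens inside dlinfer's fused kernels, and the exact formula and rounding used on Ascend may differ; this only shows the standard scheme that a scale plus zero-point pair implies.

```python
# Generic affine int8 (de)quantization with one scale and one zero point.
# Illustration only; not the Ascend kernel implementation.
import torch

def quantize_int8(x: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor):
    """q = clamp(round(x / scale) + zero, -128, 127), stored as int8."""
    q = torch.round(x / scale) + zero
    return q.clamp(-128, 127).to(torch.int8)

def dequantize_int8(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor):
    """x ~ (q - zero) * scale."""
    return (q.to(torch.float32) - zero) * scale

key = torch.randn(4, 8)        # fake key states
scale = torch.tensor([0.05])   # per-layer scale from the record file
zero = torch.tensor([3.0])     # per-layer offset (zero point)

q_key = quantize_int8(key, scale, zero)
# Round-trip error is bounded by ~scale / 2 when nothing clips.
print((dequantize_int8(q_key, scale, zero) - key).abs().max())
```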
33 changes: 31 additions & 2 deletions lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -19,6 +19,9 @@ def prefill_attention(
block_size: int,
attn_mask: Sequence[Optional[Tensor]],
is_unpaged_prefill: Optional[bool],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
) -> Tensor:
num_q_heads = query_states.shape[1]
num_kv_heads = value_states.shape[1]
@@ -53,11 +56,25 @@
num_kv_heads,
attn_mask,
attn_output=attn_output,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)


def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
max_kv_seq_len, block_offsets, block_size):
def paged_token_attention(
q,
k_cache,
v_cache,
attn_output,
kv_seq_len,
max_kv_seq_len,
block_offsets,
block_size,
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
):
num_q_heads, q_head_dim = q.shape[1:3]
num_kv_heads = k_cache.shape[-1] // q_head_dim
return ext_ops.paged_decode_attention(
@@ -71,6 +88,9 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
num_q_heads,
num_kv_heads,
attn_output=attn_output,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)


@@ -91,6 +111,9 @@ def paged_attention_fwd(
block_size: int,
attn_mask: Sequence[Optional[Tensor]] = (),
is_unpaged_prefill: Optional[bool] = None,
kv_scales: Optional[Tensor] = None,
kv_zeros: Optional[Tensor] = None,
quant_bits: Optional[int] = 0,
):
if not is_decoding:
return prefill_attention(
@@ -108,6 +131,9 @@
block_size,
attn_mask,
is_unpaged_prefill,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)
else:
return paged_token_attention(
@@ -119,4 +145,7 @@
max_kv_seq_len,
block_offsets,
block_size,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)