From 866bfa53598b939970028ce6bd0be8783f90109c Mon Sep 17 00:00:00 2001
From: yaofengchen <67218893+yao-fengchen@users.noreply.github.com>
Date: Fri, 6 Dec 2024 10:49:57 +0800
Subject: [PATCH] [ascend]feat: support kv int8 (#2736)

* [ascend]feat: support kv int8 quant

* update code

* fix error of argument missing

* update params

* fix not iterable error when quant_meta is None.

* Update ascend_en_get_started.md for kvcache quant

* Update ascend_cn_get_started.md for kvcache quant

---------

Co-authored-by: jinminxi104
---
 docs/en/get_started/ascend/get_started.md     |  6 ++
 docs/zh_cn/get_started/ascend/get_started.md  |  6 ++
 lmdeploy/messages.py                          |  7 +-
 .../backends/dlinfer/ascend/op_backend.py     | 88 ++++++++++++++++++-
 .../pytorch/backends/dlinfer/attention.py     | 37 +++++++-
 lmdeploy/pytorch/config.py                    |  1 +
 lmdeploy/pytorch/engine/cache_engine.py       |  8 +-
 lmdeploy/pytorch/engine/engine.py             |  1 +
 lmdeploy/pytorch/engine/model_agent.py        |  1 +
 .../pytorch/kernels/dlinfer/fill_kv_cache.py  | 15 +++-
 .../pytorch/kernels/dlinfer/pagedattention.py | 33 ++++++-
 lmdeploy/pytorch/model_inputs.py              |  6 ++
 lmdeploy/pytorch/tools/make_inputs.py         |  1 +
 13 files changed, 199 insertions(+), 11 deletions(-)

diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md
index 23b86afa61..d104477ca1 100644
--- a/docs/en/get_started/ascend/get_started.md
+++ b/docs/en/get_started/ascend/get_started.md
@@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
 ```
 
 Please check [supported_models](../../supported_models/supported_models.md) before use this feature.
+
+### int8 KV-cache Quantization
+
+The Ascend backend supports offline int8 KV-cache quantization in eager mode.
+
+Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details.
diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md
index b137c458be..9f0a7b1f90 100644
--- a/docs/zh_cn/get_started/ascend/get_started.md
+++ b/docs/zh_cn/get_started/ascend/get_started.md
@@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
 ```
 
 支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。
+
+### int8 KV-cache 量化
+
+昇腾后端支持在 eager 模式下使用离线 int8 KV-cache 量化。
+
+详细使用方式请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 90823598ea..2336d10752 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -293,8 +293,11 @@ def __post_init__(self):
         assert self.device_type in [
             'cuda', 'ascend', 'maca'
         ], (f'invalid device_type: {self.device_type}')
-        if self.quant_policy > 0 and self.device_type != 'cuda':
-            assert False, 'kv cache quantization only works for CUDA.'
+        if self.quant_policy > 0 and self.device_type not in [
+                'cuda', 'ascend'
+        ]:
+            assert False, \
+                'kv cache quantization only works for CUDA and ASCEND.'
 
 
 class ResponseType(enum.Enum):
diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index b6f544510b..588558f0d5 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+import itertools
+import os
+import re
+from pathlib import Path
+from typing import Dict, Tuple
 
 import torch
 
@@ -11,6 +15,71 @@
 logger = get_logger('lmdeploy')
 
 
+class AscendKVQuantMeta:
+    has_set_value: bool = False
+    quant_meta: Dict = {}
+
+    @classmethod
+    def set_value(cls, device: str, dtype: torch.dtype, record_file: str,
+                  total_layers: int):
+        with open(record_file, 'r') as file:
+            data = file.read()
+        scale_offset_pairs = re.findall(
+            r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data)
+        scale_offset_pairs = [(float(scale), float(offset))
+                              for scale, offset in scale_offset_pairs]
+        k_scales, v_scales, kv_scales = [], [], []
+        k_zeros, v_zeros, kv_zeros = [], [], []
+        if len(scale_offset_pairs) == total_layers:
+            for scale, offset in scale_offset_pairs:
+                k_scales.append(
+                    torch.tensor([scale], device=device, dtype=dtype))
+                v_scales.append(
+                    torch.tensor([scale], device=device, dtype=dtype))
+                kv_scales.append(
+                    torch.tensor([scale, scale], device=device, dtype=dtype))
+                k_zeros.append(
+                    torch.tensor([offset], device=device, dtype=dtype))
+                v_zeros.append(
+                    torch.tensor([offset], device=device, dtype=dtype))
+                kv_zeros.append(
+                    torch.tensor([offset, offset], device=device, dtype=dtype))
+        elif len(scale_offset_pairs) == total_layers * 2:
+            for i in range(total_layers):
+                scale_k, offset_k = scale_offset_pairs[2 * i]
+                scale_v, offset_v = scale_offset_pairs[2 * i + 1]
+                k_scales.append(
+                    torch.tensor([scale_k], device=device, dtype=dtype))
+                v_scales.append(
+                    torch.tensor([scale_v], device=device, dtype=dtype))
+                kv_scales.append(
+                    torch.tensor([scale_k, scale_v],
+                                 device=device,
+                                 dtype=dtype))
+                k_zeros.append(
+                    torch.tensor([offset_k], device=device, dtype=dtype))
+                v_zeros.append(
+                    torch.tensor([offset_v], device=device, dtype=dtype))
+                kv_zeros.append(
+                    torch.tensor([offset_k, offset_v],
+                                 device=device,
+                                 dtype=dtype))
+        else:
+            raise ValueError(
+                f'num of scale_offset_pairs({len(scale_offset_pairs)}) '
+                f'must match num of total_layers({total_layers})')
+
+        cls.quant_meta.update({
+            'k_scales': itertools.cycle(k_scales),
+            'k_zeros': itertools.cycle(k_zeros),
+            'v_scales': itertools.cycle(v_scales),
+            'v_zeros': itertools.cycle(v_zeros),
+            'kv_scales': itertools.cycle(kv_scales),
+            'kv_zeros': itertools.cycle(kv_zeros)
+        })
+        cls.has_set_value = True
+
+
 class AscendOpsBackend(DlinferOpsBackend):
     """ascend layer backend."""
     enable_graph = False
@@ -164,6 +233,21 @@ def get_total_slots():
             .repeat_interleave(step_context.q_seqlens, 0)
         kv_seqlens = kv_seqlens_cpu
 
+        if not cls.enable_graph and step_context.kv_quant_policy == 8:
+            record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
+            assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'
+            path = Path(record_file)
+            is_path = path.is_absolute() or path.is_relative_to('/')
+            exists = path.exists()
+            if not (is_path and exists):
+                raise ValueError(
+                    'please specify valid ASCEND_QUANT_RECORD_FILE')
+            if not AscendKVQuantMeta.has_set_value:
+                total_layers = len(step_context.kv_caches)
+                AscendKVQuantMeta.set_value(step_context.block_offsets.device,
+                                            step_context.model_config.dtype,
+                                            record_file, total_layers)
+
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets,
             q_start_loc=q_start_loc,
             q_seqlens=q_seqlens,
             kv_seqlens=kv_seqlens,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attn_mask=attn_mask,
             is_unpaged_prefill=is_unpaged_prefill,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
+            quant_policy=step_context.kv_quant_policy,
+            quant_meta=AscendKVQuantMeta.quant_meta,
         )
 
         step_context.attn_metadata = attn_metadata
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 0d666c9130..d1b5b619d0 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from dataclasses import dataclass
-from typing import Optional, Sequence
+from typing import Dict, Optional, Sequence
 
 from torch import Tensor
 
@@ -15,6 +15,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
     is_unpaged_prefill: Optional[bool] = None
     max_q_seq_len: int = 1
     max_kv_seq_len: int = 1
+    quant_meta: Dict = None
 
 
 class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
@@ -74,10 +75,37 @@ def forward(
         is_unpaged_prefill = attn_metadata.is_unpaged_prefill
         max_q_seq_len = attn_metadata.max_q_seq_len
         max_kv_seq_len = attn_metadata.max_kv_seq_len
+        quant_bits = attn_metadata.quant_policy
+        if attn_metadata.quant_meta is not None:
+            k_scales_zeros = [
+                next(attn_metadata.quant_meta['k_scales']),
+                next(attn_metadata.quant_meta['k_zeros'])
+            ] if 'k_scales' in attn_metadata.quant_meta else []
+            v_scales_zeros = [
+                next(attn_metadata.quant_meta['v_scales']),
+                next(attn_metadata.quant_meta['v_zeros'])
+            ] if 'v_scales' in attn_metadata.quant_meta else []
+            kv_scales = next(
+                attn_metadata.quant_meta['kv_scales']
+            ) if 'kv_scales' in attn_metadata.quant_meta else None
+            kv_zeros = next(
+                attn_metadata.quant_meta['kv_zeros']
+            ) if 'kv_zeros' in attn_metadata.quant_meta else None
+        else:
+            k_scales_zeros = []
+            v_scales_zeros = []
+            kv_scales = None
+            kv_zeros = None
 
         # fill kv cache
-        k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
-                                              kv_start_indices)
+        k_cache, v_cache = self.fill_kv_cache(key,
+                                              value,
+                                              k_cache,
+                                              v_cache,
+                                              kv_start_indices,
+                                              k_scales_zeros=k_scales_zeros,
+                                              v_scales_zeros=v_scales_zeros,
+                                              quant_bits=quant_bits)
 
         if inplace:
             attn_output = query[..., :self.v_head_size]
@@ -103,6 +131,9 @@ def forward(
             block_size=block_size,
             attn_mask=attn_mask,
             is_unpaged_prefill=is_unpaged_prefill,
+            kv_scales=kv_scales,
+            kv_zeros=kv_zeros,
+            quant_bits=quant_bits,
         )
 
         return attn_output
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index c350f4b4cf..da2ac35c0e 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -77,6 +77,7 @@ class CacheConfig:
     max_prefill_token_num: int = 4096
     enable_prefix_caching: bool = False
     quant_policy: Literal[0, 4, 8] = 0
+    device_type: str = 'cuda'
 
     def __post_init__(self):
         """post init."""
diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py
index e393adeed3..ffaeafa90e 100644
--- a/lmdeploy/pytorch/engine/cache_engine.py
+++ b/lmdeploy/pytorch/engine/cache_engine.py
@@ -44,7 +44,13 @@ def __init__(
         self.num_layers = model_config.num_layers
         self.kv_cache_dtype = model_config.dtype
         if cache_config.quant_policy > 0:
-            self.kv_cache_dtype = torch.uint8
+            if self.cache_config.device_type in ['cuda']:
+                self.kv_cache_dtype = torch.uint8
+            elif self.cache_config.device_type in ['ascend', 'npu']:
+                self.kv_cache_dtype = torch.int8
+            else:
+                raise ValueError(
+                    f'unsupported device_type {self.cache_config.device_type}')
 
         # Initialize the cache.
         self.local_gpu_cache = self.allocate_gpu_cache()
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index b7a803a7a7..715291a901 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -130,6 +130,7 @@ def __init__(self,
             max_prefill_token_num=engine_config.max_prefill_token_num,
             enable_prefix_caching=engine_config.enable_prefix_caching,
             quant_policy=engine_config.quant_policy,
+            device_type=engine_config.device_type,
         )
 
         if not os.path.exists(model_path):
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 59d77f264a..8e47df70b5 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -139,6 +139,7 @@ def model_forward(
         ctx_mgr = model.ctx_mgr
         context = ctx_mgr.build_context(
             inputs=inputs,
+            model_config=cache_engine.model_config,
             world_size=world_size,
             kv_caches=cache_engine.gpu_cache,
             kv_quant_policy=cache_engine.cache_config.quant_policy,
diff --git a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
index fb2eee9d41..63564d7ed8 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
 import dlinfer.ops as ext_ops
 from torch import Tensor
 
@@ -9,7 +11,16 @@ def fill_kv_cache(
     key_caches: Tensor,
     value_caches: Tensor,
     kv_start_indices: Tensor,
+    k_scales_zeros: Sequence[Optional[Tensor]],
+    v_scales_zeros: Sequence[Optional[Tensor]],
+    quant_bits: int = 0,
 ):
     """fill key/value state to cache for paged attention."""
-    return ext_ops.fill_kv_cache(key_states, value_states, key_caches,
-                                 value_caches, kv_start_indices)
+    return ext_ops.fill_kv_cache(key_states,
+                                 value_states,
+                                 key_caches,
+                                 value_caches,
+                                 kv_start_indices,
+                                 k_scales_zeros=k_scales_zeros,
+                                 v_scales_zeros=v_scales_zeros,
+                                 quant_bits=quant_bits)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index 47bcb0cfff..ded85d476d 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -19,6 +19,9 @@ def prefill_attention(
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]],
     is_unpaged_prefill: Optional[bool],
+    kv_scales: Optional[Tensor],
+    kv_zeros: Optional[Tensor],
+    quant_bits: Optional[int],
 ) -> Tensor:
     num_q_heads = query_states.shape[1]
     num_kv_heads = value_states.shape[1]
@@ -53,11 +56,25 @@ def prefill_attention(
             num_kv_heads,
             attn_mask,
             attn_output=attn_output,
+            kv_scales=kv_scales,
+            kv_zeros=kv_zeros,
+            quant_bits=quant_bits,
         )
 
 
-def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          max_kv_seq_len, block_offsets, block_size):
+def paged_token_attention(
+    q,
+    k_cache,
+    v_cache,
+    attn_output,
+    kv_seq_len,
+    max_kv_seq_len,
+    block_offsets,
+    block_size,
+    kv_scales: Optional[Tensor],
+    kv_zeros: Optional[Tensor],
+    quant_bits: Optional[int],
+):
     num_q_heads, q_head_dim = q.shape[1:3]
     num_kv_heads = k_cache.shape[-1] // q_head_dim
     return ext_ops.paged_decode_attention(
@@ -71,6 +88,9 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         num_q_heads,
         num_kv_heads,
         attn_output=attn_output,
+        kv_scales=kv_scales,
+        kv_zeros=kv_zeros,
+        quant_bits=quant_bits,
     )
 
 
@@ -91,6 +111,9 @@ def paged_attention_fwd(
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]] = (),
     is_unpaged_prefill: Optional[bool] = None,
+    kv_scales: Optional[Tensor] = None,
+    kv_zeros: Optional[Tensor] = None,
+    quant_bits: Optional[int] = 0,
 ):
     if not is_decoding:
         return prefill_attention(
@@ -108,6 +131,9 @@ def paged_attention_fwd(
             block_size,
             attn_mask,
             is_unpaged_prefill,
+            kv_scales=kv_scales,
+            kv_zeros=kv_zeros,
+            quant_bits=quant_bits,
         )
     else:
         return paged_token_attention(
@@ -119,4 +145,7 @@ def paged_attention_fwd(
             max_kv_seq_len,
             block_offsets,
             block_size,
+            kv_scales=kv_scales,
+            kv_zeros=kv_zeros,
+            quant_bits=quant_bits,
         )
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 669625d43d..d95aa6fafc 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -6,6 +6,7 @@
 import torch
 
 from lmdeploy.pytorch.backends import get_backend
+from lmdeploy.pytorch.config import ModelConfig
 
 
 @dataclass
@@ -198,6 +199,7 @@ class StepContext:
     dataclass provide these infos and tools.
     """
     input_ids: torch.LongTensor
+    model_config: ModelConfig
    block_offsets: torch.LongTensor
     position_ids: torch.LongTensor
     attention_mask: torch.LongTensor
@@ -224,6 +226,7 @@ class StepContext:
     def new(
         cls,
         inputs: ModelInputs,
+        model_config: ModelConfig,
         world_size: int = 1,
         kv_caches: List = None,
         kv_quant_policy: Literal[0, 4, 8] = 0,
@@ -273,6 +276,7 @@ def new(
 
         ret = StepContext(
             input_ids=inputs.input_ids,
+            model_config=model_config,
             block_offsets=inputs.block_offsets,
             position_ids=position_ids,
             input_embeddings=input_embeddings,
@@ -318,6 +322,7 @@ def __init__(self):
     @staticmethod
     def build_context(
         inputs: ModelInputs,
+        model_config: ModelConfig,
         world_size: int = 1,
         kv_caches: List = None,
         kv_quant_policy: Literal[0, 4, 8] = 0,
@@ -325,6 +330,7 @@ def build_context(
         """build context."""
         return StepContext.new(
             inputs,
+            model_config,
             world_size,
             kv_caches,
             kv_quant_policy,
diff --git a/lmdeploy/pytorch/tools/make_inputs.py b/lmdeploy/pytorch/tools/make_inputs.py
index f2d23830b7..053e7d0918 100644
--- a/lmdeploy/pytorch/tools/make_inputs.py
+++ b/lmdeploy/pytorch/tools/make_inputs.py
@@ -135,6 +135,7 @@ def __fill_kv_caches(kv_caches, past_key_values, block_offsets):
 
     return StepContext.new(
         inputs=model_inputs,
+        model_config=model_config,
         world_size=world_size,
         kv_caches=kv_caches,
     )
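
Editor's note, not part of the patch: the sketch below shows how the pieces introduced above would typically be switched on from the Python API. It assumes the scale/offset record file has already been produced offline with the dlinfer tooling linked in the docs; the model name and file path are placeholders, and `eager_mode=True` is an assumed switch for disabling graph mode (the quant branch in `AscendOpsBackend` only runs when `enable_graph` is off).

```python
# Hedged usage sketch under the assumptions stated above; not part of the patch.
import os

from lmdeploy import PytorchEngineConfig, pipeline

# The Ascend backend reads per-layer scales/offsets from this env var
# (see AscendKVQuantMeta.set_value); the path here is a placeholder.
os.environ['ASCEND_QUANT_RECORD_FILE'] = '/path/to/kv_quant_record.txt'

backend_config = PytorchEngineConfig(
    device_type='ascend',  # forwarded into CacheConfig.device_type, so the KV cache dtype becomes int8
    quant_policy=8,        # int8 KV cache, now accepted for 'ascend' by the messages.py change
    eager_mode=True,       # assumed flag; the quant path is only taken outside graph mode
)

# Placeholder model name; any model supported by the Ascend backend should do.
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
print(pipe(['Hello from an int8 KV cache on Ascend!']))
```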
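The layout of the record file itself belongs to the dlinfer tooling; the only contract this patch relies on is the `scale:`/`offset:` fields matched by the regex in `AscendKVQuantMeta.set_value`. A minimal sketch of that contract, with made-up labels and values:

```python
# Illustrative only: the scale/offset fields AscendKVQuantMeta.set_value looks for.
# Labels such as "layer0 key" are invented; the parser ignores everything except
# the "scale: ... offset: ..." pairs.
import re

sample_record = """\
layer0 key   scale: 0.0123 offset: -3
layer0 value scale: 0.0110 offset: 2
layer1 key   scale: 0.0151 offset: 0
layer1 value scale: 0.0098 offset: -1
"""

pairs = re.findall(r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', sample_record)
print(pairs)  # [('0.0123', '-3'), ('0.0110', '2'), ('0.0151', '0'), ('0.0098', '-1')]
```

With one pair per layer, `set_value` takes the shared branch (identical key/value scales); with two pairs per layer it takes the separate key/value branch; any other count raises the `ValueError` shown in the patch.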