
Commit 5820107

support new backend cambricon (#3002)
* [dlinfer] add camb support
* [camb] fix multiple of 8, exp raise core dump
* [camb] fix multiple of 8, exp raise core dump
* [camb] format
* [camb] pow of 2 better
* [camb] rm local_adapterids
* [camb] modify graph runner
* [camb] mock graph runner
* [camb] add requirements.txt
* [camb] post init set block_size to 16
* lint
1 parent 39af9c8 commit 5820107

File tree: 11 files changed, +182 -4 lines changed

lmdeploy/cli/utils.py (+1, -1)

@@ -377,7 +377,7 @@ def calib_search_scale(parser):
     @staticmethod
     def device(parser,
                default: str = 'cuda',
-               choices: List[str] = ['cuda', 'ascend', 'maca']):
+               choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
         """Add argument device to parser."""

         return parser.add_argument('--device',
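The practical effect of this one-line change is that --device camb becomes a valid CLI choice. A minimal standalone argparse sketch of the same pattern (the add_device function below is a hypothetical stand-in, not the actual helper in lmdeploy/cli/utils.py):

# Standalone sketch mirroring the patched helper; add_device is hypothetical.
import argparse
from typing import List


def add_device(parser: argparse.ArgumentParser,
               default: str = 'cuda',
               choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
    """Add a --device argument with the newly allowed 'camb' choice."""
    return parser.add_argument('--device',
                               type=str,
                               default=default,
                               choices=choices,
                               help='device type for inference')


parser = argparse.ArgumentParser()
add_device(parser)
args = parser.parse_args(['--device', 'camb'])
assert args.device == 'camb'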

lmdeploy/messages.py (+9, -1)

@@ -7,6 +7,9 @@
 from pydantic.dataclasses import dataclass as pydantic_dataclass

 from .tokenizer import Tokenizer
+from .utils import get_logger
+
+logger = get_logger('lmdeploy')

 LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
 """LogitsProcessor is a function that takes a tensor of input_ids, the logits

@@ -297,13 +300,18 @@ def __post_init__(self):
         assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
         assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
         assert self.device_type in [
-            'cuda', 'ascend', 'maca'
+            'cuda', 'ascend', 'maca', 'camb'
         ], (f'invalid device_type: {self.device_type}')
         if self.quant_policy > 0 and self.device_type not in [
                 'cuda', 'ascend'
         ]:
             assert False, \
                 'kv cache quantization only works for CUDA and ASCEND.'
+        if self.device_type == 'camb' and self.block_size != 16:
+            self.block_size = 16
+            logger.warning(
+                'Currently, camb device requires block size to be 16, \
+                setting block size to 16')


 class ResponseType(enum.Enum):
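The effect of the new __post_init__ guard, sketched with a standalone dataclass (EngineConfigSketch is a hypothetical stand-in for the engine config dataclass patched above): a camb configuration that asks for any block size other than 16 is coerced back to 16 and a warning is logged.

# Standalone sketch of the block-size guard; EngineConfigSketch is hypothetical.
import logging
from dataclasses import dataclass

logger = logging.getLogger('lmdeploy')


@dataclass
class EngineConfigSketch:
    device_type: str = 'cuda'
    block_size: int = 64

    def __post_init__(self):
        assert self.device_type in ['cuda', 'ascend', 'maca', 'camb']
        if self.device_type == 'camb' and self.block_size != 16:
            self.block_size = 16
            logger.warning('camb device requires block size to be 16, '
                           'setting block size to 16')


cfg = EngineConfigSketch(device_type='camb', block_size=64)
assert cfg.block_size == 16  # coerced by the guard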
lmdeploy/pytorch/backends/dlinfer/__init__.py (+1)

@@ -1,3 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ascend import AscendOpsBackend  # noqa: F401
+from .camb import CambOpsBackend  # noqa: F401
 from .maca import MacaOpsBackend  # noqa: F401
lmdeploy/pytorch/backends/dlinfer/camb/__init__.py (new file)

@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .op_backend import CambOpsBackend  # noqa: F401
lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py (new file)

@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
+from lmdeploy.utils import get_logger
+
+from ..op_backend import DlinferOpsBackend
+
+logger = get_logger('lmdeploy')
+
+
+class CambOpsBackend(DlinferOpsBackend):
+    """camb layer backend."""
+    total_slots = None
+
+    @staticmethod
+    def get_name() -> str:
+        """backend name."""
+        return 'camb'
+
+    @staticmethod
+    def get_k_block_shape(
+        block_size: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+    ) -> Tuple[int, ...]:
+        return (
+            num_heads,
+            block_size,
+            head_size,
+        )
+
+    @staticmethod
+    def get_v_block_shape(
+        block_size: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+    ) -> Tuple[int, ...]:
+        return (
+            num_heads,
+            block_size,
+            head_size,
+        )
+
+    @classmethod
+    def update_step_context(cls, step_context):
+        """update step context."""
+
+        def get_total_slots():
+            if cls.total_slots is None:
+                cls.total_slots = torch.arange(
+                    block_num * block_size,
+                    dtype=torch.int32,
+                    device=step_context.block_offsets.device)
+                cls.total_slots = cls.total_slots.view(block_num, block_size)
+            return cls.total_slots
+
+        kv_start_indices = []
+        block_num, _, block_size, _ = step_context.kv_caches[0][0].shape
+
+        is_unpaged_prefill = False
+        q_start_loc = step_context.q_start_loc
+        q_seqlens = step_context.q_seqlens
+        kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+        block_offsets = step_context.block_offsets.to(torch.int32)
+        max_q_seq_len = torch.max(q_seqlens).cpu().item()
+        max_kv_seq_len = torch.max(kv_seqlens).cpu().item()
+
+        cu_seqlens = torch.cat(
+            (q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
+        cu_seq_lens_kv = None
+
+        q_seqlens_list = step_context.q_seqlens.tolist()
+        kv_seqlens_list = step_context.kv_seqlens.tolist()
+        if not step_context.is_decoding:
+            is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
+            # get kv_indices
+            for i in range(q_start_loc.size(0)):
+                q_seq_len = q_seqlens_list[i]
+                kv_seq_len = kv_seqlens_list[i]
+                # collect kv start indices.
+                history_length = kv_seq_len - q_seq_len
+                total_slots = get_total_slots()
+                slot_tables = total_slots[block_offsets[i]].view(-1)
+                slots = slot_tables[history_length:kv_seq_len]
+                kv_start_indices.append(slots)
+            kv_start_indices = torch.cat(kv_start_indices)
+            if not is_unpaged_prefill:
+                cu_seq_lens_kv = torch.cat(
+                    (torch.tensor([0], device=kv_seqlens.device),
+                     kv_seqlens.cumsum(0))).int()
+        else:
+            # collect kv_start_indices without using a for-loop,
+            # (fill kv-cache for just ONE token during the decoding phase)
+            idx = (step_context.kv_seqlens - 1) % block_size
+            block_num = (step_context.kv_seqlens - 1) // block_size
+            last_block = block_offsets.gather(  # dtype of gather must be int64
+                1, block_num.view(-1, 1)).view(-1)
+            kv_start_indices = (last_block * block_size + idx).to(torch.int32)
+
+        attn_meta_cls = cls.get_attention_metadata_cls()
+        attn_metadata = attn_meta_cls(
+            step_context.is_decoding,
+            block_offsets,
+            q_start_loc=cu_seqlens,
+            cu_seq_lens_kv=cu_seq_lens_kv,
+            q_seqlens=q_seqlens,
+            kv_seqlens=kv_seqlens,
+            kv_start_indices=kv_start_indices,
+            block_size=block_size,
+            attention_mask=None,
+            is_unpaged_prefill=is_unpaged_prefill,
+            max_q_seq_len=max_q_seq_len,
+            max_kv_seq_len=max_kv_seq_len,
+        )
+
+        step_context.attn_metadata = attn_metadata
+        return step_context
+
+    @staticmethod
+    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
+                           cache_config: CacheConfig,
+                           backend_config: BackendConfig,
+                           device: torch.device):
+        """build graph runner."""
+        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
+        return CUDAGraphRunner(model, model_config, cache_config,
+                               backend_config, device)
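The heart of update_step_context is turning per-sequence block tables into flat KV-cache slot indices: during prefill each new token takes a slot from the sequence's slot table, while during decoding only the newest token needs last_block * block_size + offset. A toy example with small shapes (illustrative only, assuming the same slot layout as get_total_slots above):

# Toy illustration of the slot-index arithmetic used in update_step_context.
# Assumes slot id = block_id * block_size + offset_in_block.
import torch

block_num, block_size = 4, 16
total_slots = torch.arange(block_num * block_size,
                           dtype=torch.int32).view(block_num, block_size)

# Prefill: one sequence, 8 kv tokens total, 5 of them new (3 already cached),
# spread over blocks [2, 0] of its block table.
block_offsets = torch.tensor([[2, 0]])
q_seq_len, kv_seq_len = 5, 8
history_length = kv_seq_len - q_seq_len                # 3 cached tokens
slot_tables = total_slots[block_offsets[0]].view(-1)   # slots of block 2, then 0
prefill_slots = slot_tables[history_length:kv_seq_len]
print(prefill_slots)  # tensor([35, 36, 37, 38, 39], dtype=torch.int32)

# Decode: only the newest token of each sequence needs a slot.
kv_seqlens = torch.tensor([8])
idx = (kv_seqlens - 1) % block_size                    # offset inside last block
last_block_pos = (kv_seqlens - 1) // block_size        # index into block table
last_block = block_offsets.gather(1, last_block_pos.view(-1, 1)).view(-1)
decode_slots = (last_block * block_size + idx).to(torch.int32)
print(decode_slots)  # tensor([39], dtype=torch.int32), same slot as above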

lmdeploy/pytorch/backends/selector.py (+3)

@@ -18,5 +18,8 @@ def get_backend():
     if device_type == 'maca':
         from .dlinfer import MacaOpsBackend
         return MacaOpsBackend
+    if device_type == 'camb':
+        from .dlinfer import CambOpsBackend
+        return CambOpsBackend
     else:
         raise RuntimeError(f'Unsupported device type: {device_type}')

lmdeploy/pytorch/check_env/deeplink.py (+1)

@@ -5,6 +5,7 @@
     'ascend',
     'npu',
     'maca',
+    'camb',
 ]
lmdeploy/pytorch/models/module_map.py (+3, -1)

@@ -6,9 +6,11 @@
 MODULE_MAP = dict()
 ASCEND_MODULE_MAP = dict()
 MACA_MODULE_MAP = dict()
+CAMB_MODULE_MAP = dict()

 DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP,
-                                 maca=MACA_MODULE_MAP)
+                                 maca=MACA_MODULE_MAP,
+                                 camb=CAMB_MODULE_MAP)

 # llama
 MODULE_MAP.update({

lmdeploy/utils.py (+5, -1)

@@ -332,7 +332,7 @@ def get_max_batch_size(device_type: str):
     Args:
         device_type (str): the type of device
     """
-    assert device_type in ['cuda', 'ascend', 'maca']
+    assert device_type in ['cuda', 'ascend', 'maca', 'camb']
     if device_type == 'cuda':
         max_batch_size_map = {
             'a100': 256,

@@ -352,6 +352,8 @@ def get_max_batch_size(device_type: str):
         return 16
     elif device_type == 'maca':
         return 128
+    elif device_type == 'camb':
+        return 128


 def is_bf16_supported(device_type: str = 'cuda'):

@@ -387,5 +389,7 @@ def is_bf16_supported(device_type: str = 'cuda'):
         # return False
     elif device_type == 'maca':
         return True
+    elif device_type == 'camb':
+        return True
     else:
         return False
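A quick sanity check of the two new branches (assumes an lmdeploy build that already includes this commit; both helpers are the ones patched in lmdeploy/utils.py above):

# Assumes lmdeploy with this commit installed; exercises the branches added above.
from lmdeploy.utils import get_max_batch_size, is_bf16_supported

assert get_max_batch_size('camb') == 128  # camb defaults to batch size 128
assert is_bf16_supported('camb') is True  # camb reports bf16 support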

requirements/runtime_camb.txt (new file, +21)

@@ -0,0 +1,21 @@
+accelerate==1.2.0
+einops
+fastapi
+fire
+mmengine-lite
+numpy<2.0.0
+openai
+outlines<0.1.0
+peft<=0.11.1
+pillow
+protobuf
+pydantic>2.0.0
+pynvml
+safetensors
+sentencepiece
+shortuuid
+tiktoken
+torch==2.4.0
+torchvision<=0.19.0,>=0.15.0
+transformers
+uvicorn

requirements_camb.txt (new file, +4)

@@ -0,0 +1,4 @@
+-r requirements/build.txt
+-r requirements/runtime_camb.txt
+-r requirements/lite.txt
+-r requirements/serve.txt
