# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')


class CambOpsBackend(DlinferOpsBackend):
    """Camb (Cambricon) layer backend."""
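    # lazily built [block_num, block_size] table mapping (block, offset)
    # pairs to global kv-cache slot ids, reused across steps.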
    total_slots = None

    @staticmethod
    def get_name() -> str:
        """Backend name."""
        return 'camb'

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
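        # per-block k-cache layout: (num_heads, block_size, head_size)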
        return (
            num_heads,
            block_size,
            head_size,
        )

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
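        # v-cache blocks use the same (num_heads, block_size, head_size) layout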
        return (
            num_heads,
            block_size,
            head_size,
        )

    @classmethod
    def update_step_context(cls, step_context):
        """Update step context."""

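        # build (and cache) the flattened slot-id table on first use; it only
        # depends on the cache geometry, so it can be reused across steps.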
        def get_total_slots():
            if cls.total_slots is None:
                cls.total_slots = torch.arange(
                    block_num * block_size,
                    dtype=torch.int32,
                    device=step_context.block_offsets.device)
                cls.total_slots = cls.total_slots.view(block_num, block_size)
            return cls.total_slots

        kv_start_indices = []
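        # k-cache tensors are (block_num, num_heads, block_size, head_size),
        # so dims 0 and 2 give the cache geometry needed below.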
        block_num, _, block_size, _ = step_context.kv_caches[0][0].shape

        is_unpaged_prefill = False
        q_start_loc = step_context.q_start_loc
        q_seqlens = step_context.q_seqlens
        kv_seqlens = step_context.kv_seqlens.to(torch.int32)
        block_offsets = step_context.block_offsets.to(torch.int32)
        max_q_seq_len = torch.max(q_seqlens).cpu().item()
        max_kv_seq_len = torch.max(kv_seqlens).cpu().item()

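        # cumulative query offsets: per-sequence start locations with the
        # total number of query tokens appended at the end.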
        cu_seqlens = torch.cat(
            (q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
        cu_seq_lens_kv = None

        q_seqlens_list = step_context.q_seqlens.tolist()
        kv_seqlens_list = step_context.kv_seqlens.tolist()
        if not step_context.is_decoding:
            # prefill is unpaged when no sequence carries kv-cache history
            is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
            # collect the kv-cache slot indices of the new tokens per sequence
            total_slots = get_total_slots()
            for i in range(q_start_loc.size(0)):
                q_seq_len = q_seqlens_list[i]
                kv_seq_len = kv_seqlens_list[i]
                # skip the slots already occupied by the history tokens
                history_length = kv_seq_len - q_seq_len
                slot_tables = total_slots[block_offsets[i]].view(-1)
                slots = slot_tables[history_length:kv_seq_len]
                kv_start_indices.append(slots)
            kv_start_indices = torch.cat(kv_start_indices)
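            # paged prefill additionally needs cumulative kv sequence lengths
            # for the attention kernel.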
            if not is_unpaged_prefill:
                cu_seq_lens_kv = torch.cat(
                    (torch.tensor([0], device=kv_seqlens.device),
                     kv_seqlens.cumsum(0))).int()
        else:
            # decoding fills the kv cache for exactly one new token per
            # sequence, so the slot indices can be computed without a loop.
            idx = (step_context.kv_seqlens - 1) % block_size
            block_idx = (step_context.kv_seqlens - 1) // block_size
            # torch.gather requires int64 indices, hence block_idx is built
            # from the original (uncast) kv_seqlens.
            last_block = block_offsets.gather(
                1, block_idx.view(-1, 1)).view(-1)
            kv_start_indices = (last_block * block_size + idx).to(torch.int32)

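        # pack everything into this backend's attention metadata for the
        # attention ops to consume during the forward pass.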
        attn_meta_cls = cls.get_attention_metadata_cls()
        attn_metadata = attn_meta_cls(
            step_context.is_decoding,
            block_offsets,
            q_start_loc=cu_seqlens,
            cu_seq_lens_kv=cu_seq_lens_kv,
            q_seqlens=q_seqlens,
            kv_seqlens=kv_seqlens,
            kv_start_indices=kv_start_indices,
            block_size=block_size,
            attention_mask=None,
            is_unpaged_prefill=is_unpaged_prefill,
            max_q_seq_len=max_q_seq_len,
            max_kv_seq_len=max_kv_seq_len,
        )

        step_context.attn_metadata = attn_metadata
        return step_context

    @staticmethod
    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
                           cache_config: CacheConfig,
                           backend_config: BackendConfig,
                           device: torch.device):
        """Build graph runner."""
        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
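        # reuse the CUDA backend's graph runner implementation as-is.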
        return CUDAGraphRunner(model, model_config, cache_config,
                               backend_config, device)